From 2890709e930e72ea1ad258a1011296584f5966c3 Mon Sep 17 00:00:00 2001 From: Teodor-Dumitru Ene Date: Thu, 26 Feb 2026 13:16:33 -0600 Subject: [PATCH 1/8] Fix illegal memory access with mamba inference --- megatron/core/inference/contexts/dynamic_context.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index e9e033c9eee..ef2a6aaa5cb 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -609,9 +609,6 @@ def _allocate_memory_buffer(self): def _allocate_mamba_states(self): """Allocate Mamba states for hybrid models.""" if self.is_hybrid_model: - self.mamba_metadata = MambaMetadata( - max_requests=self.max_requests, max_tokens=self.max_tokens - ) self.mamba_conv_states = torch.empty( (self.num_mamba_layers, self.max_requests) + self.mamba_conv_states_shape, dtype=self.params_dtype, @@ -695,6 +692,11 @@ def initialize_all_tensors(self) -> None: self.token_to_position_in_request = torch.empty_like(self.token_to_input_ids) self.token_to_local_position_within_kv_block = torch.empty_like(self.token_to_input_ids) + if self.is_hybrid_model: + self.mamba_metadata = MambaMetadata( + max_requests=self.max_requests, max_tokens=self.max_tokens + ) + # Allocate large non-graphed buffers. 
need_static_addr = ( self.static_kv_memory_pointers @@ -731,10 +733,10 @@ def reinitialize_inference_state_buffers(self): return if self.unified_memory_level != 0 or self._uses_torch_memory_saver: - if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: - self.reset() if self._uses_torch_memory_saver: torch_memory_saver.resume("inference_context") + if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: + self.reset() return if self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD: From 7ffe167866b9767612cf5c4033026683fbf2f9bc Mon Sep 17 00:00:00 2001 From: Teodor-Dumitru Ene Date: Fri, 27 Feb 2026 11:36:49 -0600 Subject: [PATCH 2/8] Address reviewer comments --- megatron/core/inference/contexts/dynamic_context.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index ef2a6aaa5cb..fe50d152ca6 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -692,6 +692,7 @@ def initialize_all_tensors(self) -> None: self.token_to_position_in_request = torch.empty_like(self.token_to_input_ids) self.token_to_local_position_within_kv_block = torch.empty_like(self.token_to_input_ids) + # NOTE: Need to build this outside the UVM / TMS context to avoid IMA. if self.is_hybrid_model: self.mamba_metadata = MambaMetadata( max_requests=self.max_requests, max_tokens=self.max_tokens @@ -733,6 +734,7 @@ def reinitialize_inference_state_buffers(self): return if self.unified_memory_level != 0 or self._uses_torch_memory_saver: + # Need to bring back the memory block before we reset it. 
if self._uses_torch_memory_saver: torch_memory_saver.resume("inference_context") if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: From 2021b8455f6de7728580da257bfce88136470a7a Mon Sep 17 00:00:00 2001 From: Teodor-Dumitru Ene Date: Mon, 2 Mar 2026 03:24:44 -0600 Subject: [PATCH 3/8] Fix no-persist-CG case as well --- .../inference/contexts/dynamic_context.py | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index e58683d3516..ea5bb60d8e5 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -736,21 +736,24 @@ def reinitialize_inference_state_buffers(self): if self.kv_cache_management_mode == KVCacheManagementMode.PERSIST: return - if self.unified_memory_level != 0 or self._uses_torch_memory_saver: - # Need to bring back the memory block before we reset it. - if self._uses_torch_memory_saver: - torch_memory_saver.resume("inference_context") - if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: - self.reset() - return - - if self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD: - for name, tensor in ((n, getattr(self, n)) for n in self._offloadable_tensor_names): - tensor.storage().resize_(self._offloadable_storage_sizes[name]) - tensor.copy_(self._offloadable_cpu_backups[name], non_blocking=True) - elif self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: - self.is_tensor_state_allocated = False - self.initialize_all_tensors() + if self._uses_torch_memory_saver: + # Need to bring back the memory block before we do anything else. + torch_memory_saver.resume("inference_context") + elif self.unified_memory_level != 0: + # Include this for code readability. 
+ pass + else: + if self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD: + for name, tensor in ((n, getattr(self, n)) for n in self._offloadable_tensor_names): + tensor.storage().resize_(self._offloadable_storage_sizes[name]) + tensor.copy_(self._offloadable_cpu_backups[name], non_blocking=True) + elif self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: + self.is_tensor_state_allocated = False + self.initialize_all_tensors() + + # No matter which memory mode, in RECOMPUTE we need to reset the context. + if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: + self.reset() def deallocate_inference_state_buffers(self): """Deallocate large tensors (KV cache, Mamba states) during suspend. From cc85fcdb182b6527ca3db25949b62af7c666cc4f Mon Sep 17 00:00:00 2001 From: Teodor-Dumitru Ene Date: Mon, 2 Mar 2026 05:57:23 -0600 Subject: [PATCH 4/8] Alternative implementation to the reset issue --- .../inference/contexts/dynamic_context.py | 38 +++++++++---------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index ea5bb60d8e5..aabe3b58323 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -720,9 +720,8 @@ def initialize_all_tensors(self) -> None: self._allocate_memory_buffer() self._allocate_mamba_states() - # Reset attention and Mamba state. - self.reset_attention_state() - self.reset_mamba_state() + # Reset entire context state. + self.reset() def reinitialize_inference_state_buffers(self): """Restore large tensors (KV cache, Mamba states) after a suspend. @@ -736,24 +735,21 @@ def reinitialize_inference_state_buffers(self): if self.kv_cache_management_mode == KVCacheManagementMode.PERSIST: return - if self._uses_torch_memory_saver: - # Need to bring back the memory block before we do anything else. 
- torch_memory_saver.resume("inference_context") - elif self.unified_memory_level != 0: - # Include this for code readability. - pass - else: - if self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD: - for name, tensor in ((n, getattr(self, n)) for n in self._offloadable_tensor_names): - tensor.storage().resize_(self._offloadable_storage_sizes[name]) - tensor.copy_(self._offloadable_cpu_backups[name], non_blocking=True) - elif self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: - self.is_tensor_state_allocated = False - self.initialize_all_tensors() - - # No matter which memory mode, in RECOMPUTE we need to reset the context. - if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: - self.reset() + if self.unified_memory_level != 0 or self._uses_torch_memory_saver: + # Need to bring back the memory block before we reset it. + if self._uses_torch_memory_saver: + torch_memory_saver.resume("inference_context") + if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: + self.reset() + return + + if self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD: + for name, tensor in ((n, getattr(self, n)) for n in self._offloadable_tensor_names): + tensor.storage().resize_(self._offloadable_storage_sizes[name]) + tensor.copy_(self._offloadable_cpu_backups[name], non_blocking=True) + elif self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: + self.is_tensor_state_allocated = False + self.initialize_all_tensors() def deallocate_inference_state_buffers(self): """Deallocate large tensors (KV cache, Mamba states) during suspend. 
From 68f8eb32fa42723c83fab54b3f52e96288796af2 Mon Sep 17 00:00:00 2001 From: Teodor-Dumitru Ene Date: Mon, 2 Mar 2026 10:53:12 -0600 Subject: [PATCH 5/8] Split up reset upon reviewer comment --- .../inference/contexts/dynamic_context.py | 88 +++++++++---------- 1 file changed, 42 insertions(+), 46 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index aabe3b58323..01eec0d85c6 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -454,20 +454,6 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC self.params_dtype = model_config.params_dtype self.max_sequence_length = inference_config.max_sequence_length - # Request and token counts. - self.total_request_count = 0 - self.active_token_count = 0 - self.paused_request_count = 0 - self.batch_dimensions = InferenceBatchDimensions( - token_count=0, prefill_req_count=0, decode_req_count=0 - ) - self.padded_batch_dimensions = InferenceBatchDimensions( - token_count=0, prefill_req_count=0, decode_req_count=0 - ) - self.padded_active_token_count = 0 - self.padded_active_request_count = 0 - self.paused_tokens = None - # Block ids. 
self.max_kv_block_count = math.ceil(self.max_sequence_length / self.block_size_tokens) @@ -500,11 +486,8 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC # Attention metadata initialization (tensors are now handled by MHAMetadata classes) - self.num_prefill_requests = 0 self.graph_attn_metadata = {} self.non_graph_attn_metadata = {} - self.active_attn_metadata = None - self.is_creating_cuda_graphs = False self.graph_attn_metadata["mha_metadata"] = GraphedMHAMetadata( block_count_total=self.block_allocator.total_count, @@ -547,10 +530,8 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC ) self.cuda_graph_mixed_prefill_count = inference_config.cuda_graph_mixed_prefill_count - self._using_cuda_graph_this_step = False # Deal with chunked prefill self.enable_chunked_prefill = inference_config.enable_chunked_prefill - self.chunked_prefill_request_id = -1 # FlashInfer. if inference_config.use_flashinfer_fused_rope is True: @@ -563,6 +544,7 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC self.is_tensor_state_allocated = False self.is_symmetric_memory_initialized = False self.initialize_all_tensors() + self.reset_metadata() # Print info. logging.info( @@ -720,9 +702,6 @@ def initialize_all_tensors(self) -> None: self._allocate_memory_buffer() self._allocate_mamba_states() - # Reset entire context state. - self.reset() - def reinitialize_inference_state_buffers(self): """Restore large tensors (KV cache, Mamba states) after a suspend. 
@@ -740,7 +719,7 @@ def reinitialize_inference_state_buffers(self): if self._uses_torch_memory_saver: torch_memory_saver.resume("inference_context") if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: - self.reset() + self.reset_metadata() return if self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD: @@ -750,6 +729,7 @@ def reinitialize_inference_state_buffers(self): elif self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: self.is_tensor_state_allocated = False self.initialize_all_tensors() + self.reset_metadata() def deallocate_inference_state_buffers(self): """Deallocate large tensors (KV cache, Mamba states) during suspend. @@ -1455,30 +1435,13 @@ def initialize_attention_state( else: self.moe_routing_metadata.disable_static_buffer_recording() - def reset(self) -> None: - """Reset entire context. - - This method does: - - Reset active/paused request/token counts to zero. - - Reset available blocks to entire memory. - - Reset other tensors to zeros (unncessary, just or sanity checking). + def reset_tensors(self) -> None: + """Fill all GPU tensors with sentinel values. - This method is useful after cuda graph warmup iterations, where the - context's memory buffer is referenced by the cuda graph system and - cannot be deallocated. + This is unnecessary for correctness (tensors are overwritten before use) + but useful for sanity checking and debugging. """ - # Reset request/token counts. - self.total_request_count = 0 - self.active_token_count = 0 - self.paused_request_count = 0 - self.batch_dimensions = InferenceBatchDimensions( - token_count=0, prefill_req_count=0, decode_req_count=0 - ) - self.padded_active_token_count = 0 - self.padded_active_request_count = 0 - self.paused_tokens = None - # Reset request indexes. 
self.request_ids.fill_(-1) self.request_query_lengths.fill_(0) @@ -1501,13 +1464,31 @@ def reset(self) -> None: self.token_to_block_idx.fill_(-1) self.token_to_local_position_within_kv_block.fill_(0) - # Reset available block count. + def reset_metadata(self) -> None: + """Reset all bookkeeping state: counters, block allocator, attention/mamba state. + + This must be called after ``initialize_all_tensors()`` and after any + suspend/resume cycle to bring the context back to a clean state. + """ + + # Reset request/token counts. + self.total_request_count = 0 + self.active_token_count = 0 + self.paused_request_count = 0 + self.batch_dimensions = InferenceBatchDimensions( + token_count=0, prefill_req_count=0, decode_req_count=0 + ) + self.padded_active_token_count = 0 + self.padded_active_request_count = 0 + self.paused_tokens = None + + # Reset attention, mamba, and block allocator state. self.reset_attention_state() self.reset_mamba_state() self.block_allocator.reset() self.request_to_kv_block_ids.fill_(-1) - # Reset chunked prefill state + # Reset chunked prefill state. self.chunked_prefill_request_id = -1 self.num_prefill_requests = 0 self._using_cuda_graph_this_step = False @@ -1516,6 +1497,21 @@ def reset(self) -> None: token_count=0, prefill_req_count=0, decode_req_count=0 ) + def reset(self) -> None: + """Reset entire context. + + This method does: + - Fill all GPU tensors with sentinel values. + - Reset active/paused request/token counts to zero. + - Reset available blocks to entire memory. + + This method is useful after cuda graph warmup iterations, where the + context's memory buffer is referenced by the cuda graph system and + cannot be deallocated. 
+ """ + self.reset_tensors() + self.reset_metadata() + def current_input_and_position_ids( self, *, num_warmup_tokens: Optional[int] = None ) -> Tuple[Tensor, Tensor]: From a86c6931423cc914fd384e90dd0c65ec8c76fcb5 Mon Sep 17 00:00:00 2001 From: Teodor-Dumitru Ene Date: Mon, 2 Mar 2026 10:58:18 -0600 Subject: [PATCH 6/8] Make if statements easier to read --- .../inference/contexts/dynamic_context.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 01eec0d85c6..0bbb79fdfa0 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -714,23 +714,26 @@ def reinitialize_inference_state_buffers(self): if self.kv_cache_management_mode == KVCacheManagementMode.PERSIST: return - if self.unified_memory_level != 0 or self._uses_torch_memory_saver: + if self._uses_torch_memory_saver: # Need to bring back the memory block before we reset it. - if self._uses_torch_memory_saver: - torch_memory_saver.resume("inference_context") - if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: - self.reset_metadata() - return - - if self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD: - for name, tensor in ((n, getattr(self, n)) for n in self._offloadable_tensor_names): - tensor.storage().resize_(self._offloadable_storage_sizes[name]) - tensor.copy_(self._offloadable_cpu_backups[name], non_blocking=True) - elif self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: - self.is_tensor_state_allocated = False - self.initialize_all_tensors() + torch_memory_saver.resume("inference_context") + elif self.unified_memory_level != 0: + # Include this for code readability. 
+ pass + else: + if self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD: + for name, tensor in ((n, getattr(self, n)) for n in self._offloadable_tensor_names): + tensor.storage().resize_(self._offloadable_storage_sizes[name]) + tensor.copy_(self._offloadable_cpu_backups[name], non_blocking=True) + elif self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: + self.is_tensor_state_allocated = False + self.initialize_all_tensors() + + # No matter which memory mode, in RECOMPUTE we need to reset the metadata. + if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: self.reset_metadata() + def deallocate_inference_state_buffers(self): """Deallocate large tensors (KV cache, Mamba states) during suspend. @@ -1436,11 +1439,7 @@ def initialize_attention_state( self.moe_routing_metadata.disable_static_buffer_recording() def reset_tensors(self) -> None: - """Fill all GPU tensors with sentinel values. - - This is unnecessary for correctness (tensors are overwritten before use) - but useful for sanity checking and debugging. - """ + """Fill all GPU tensors with sentinel values.""" # Reset request indexes. self.request_ids.fill_(-1) From 459610946b94bb81ee64233ae28c5bdead6979ff Mon Sep 17 00:00:00 2001 From: Teodor-Dumitru Ene Date: Mon, 2 Mar 2026 19:00:22 -0600 Subject: [PATCH 7/8] lint --- megatron/core/inference/contexts/dynamic_context.py | 1 - 1 file changed, 1 deletion(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 0bbb79fdfa0..9be63c19bf4 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -733,7 +733,6 @@ def reinitialize_inference_state_buffers(self): if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: self.reset_metadata() - def deallocate_inference_state_buffers(self): """Deallocate large tensors (KV cache, Mamba states) during suspend. 
From 73f84757cb4c41260178f491ab5a6b0f1b9103e9 Mon Sep 17 00:00:00 2001 From: Teodor-Dumitru Ene Date: Mon, 2 Mar 2026 21:38:58 -0600 Subject: [PATCH 8/8] Address reviewer comments --- .../inference/contexts/dynamic_context.py | 39 ++++++++----------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 9be63c19bf4..f2c6637f0bd 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -544,7 +544,6 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC self.is_tensor_state_allocated = False self.is_symmetric_memory_initialized = False self.initialize_all_tensors() - self.reset_metadata() # Print info. logging.info( @@ -702,6 +701,9 @@ def initialize_all_tensors(self) -> None: self._allocate_memory_buffer() self._allocate_mamba_states() + # Reset tensor-related metadata. + self.reset_metadata() + def reinitialize_inference_state_buffers(self): """Restore large tensors (KV cache, Mamba states) after a suspend. @@ -714,24 +716,21 @@ def reinitialize_inference_state_buffers(self): if self.kv_cache_management_mode == KVCacheManagementMode.PERSIST: return - if self._uses_torch_memory_saver: + if self.unified_memory_level != 0 or self._uses_torch_memory_saver: # Need to bring back the memory block before we reset it. 
- pass - else: - if self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD: - for name, tensor in ((n, getattr(self, n)) for n in self._offloadable_tensor_names): - tensor.storage().resize_(self._offloadable_storage_sizes[name]) - tensor.copy_(self._offloadable_cpu_backups[name], non_blocking=True) - elif self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: - self.is_tensor_state_allocated = False - self.initialize_all_tensors() - - # No matter which memory mode, in RECOMPUTE we need to reset the metadata. - if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: - self.reset_metadata() + if self._uses_torch_memory_saver: + torch_memory_saver.resume("inference_context") + if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: + self.reset_metadata() + return + + if self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD: + for name, tensor in ((n, getattr(self, n)) for n in self._offloadable_tensor_names): + tensor.storage().resize_(self._offloadable_storage_sizes[name]) + tensor.copy_(self._offloadable_cpu_backups[name], non_blocking=True) + elif self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: + self.is_tensor_state_allocated = False + self.initialize_all_tensors() def deallocate_inference_state_buffers(self): """Deallocate large tensors (KV cache, Mamba states) during suspend. @@ -1502,10 +1501,6 @@ def reset(self) -> None: - Fill all GPU tensors with sentinel values. - Reset active/paused request/token counts to zero. - Reset available blocks to entire memory. - - This method is useful after cuda graph warmup iterations, where the - context's memory buffer is referenced by the cuda graph system and - cannot be deallocated. """ self.reset_tensors() self.reset_metadata()