diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py
index 6cb4a009642..3de6017fa9c 100644
--- a/megatron/core/inference/contexts/dynamic_context.py
+++ b/megatron/core/inference/contexts/dynamic_context.py
@@ -469,21 +469,6 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
         self.params_dtype = model_config.params_dtype
         self.max_sequence_length = inference_config.max_sequence_length
 
-        # Request and token counts.
-        self.total_request_count = 0
-        self.active_token_count = 0
-        self.lifetime_prefill_token_count = 0
-        self.paused_request_count = 0
-        self.batch_dimensions = InferenceBatchDimensions(
-            token_count=0, prefill_req_count=0, decode_req_count=0
-        )
-        self.padded_batch_dimensions = InferenceBatchDimensions(
-            token_count=0, prefill_req_count=0, decode_req_count=0
-        )
-        self.padded_active_token_count = 0
-        self.padded_active_request_count = 0
-        self.paused_tokens = None
-
         # Block ids.
         self.max_kv_block_count = math.ceil(self.max_sequence_length / self.block_size_tokens)
 
@@ -516,11 +501,8 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
 
         # Attention metadata initialization (tensors are now handled by MHAMetadata classes)
-        self.num_prefill_requests = 0
         self.graph_attn_metadata = {}
         self.non_graph_attn_metadata = {}
-        self.active_attn_metadata = None
-        self.is_creating_cuda_graphs = False
 
         self.graph_attn_metadata["mha_metadata"] = GraphedMHAMetadata(
             block_count_total=self.block_allocator.total_count,
@@ -566,10 +548,8 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
             inference_config.cuda_graph_mixed_prefill_count, self.max_requests
         )
-        self._using_cuda_graph_this_step = False
 
         # Deal with chunked prefill
         self.enable_chunked_prefill = inference_config.enable_chunked_prefill
-        self.chunked_prefill_request_id = -1
 
         # FlashInfer.
         if inference_config.use_flashinfer_fused_rope is True:
@@ -631,9 +611,6 @@ def _allocate_memory_buffer(self):
     def _allocate_mamba_states(self):
         """Allocate Mamba states for hybrid models."""
         if self.is_hybrid_model:
-            self.mamba_metadata = MambaMetadata(
-                max_requests=self.max_requests, max_tokens=self.max_tokens
-            )
             self.mamba_conv_states = torch.empty(
                 (self.num_mamba_layers, self.max_requests) + self.mamba_conv_states_shape,
                 dtype=self.params_dtype,
@@ -717,6 +694,12 @@ def initialize_all_tensors(self) -> None:
         self.token_to_position_in_request = torch.empty_like(self.token_to_input_ids)
         self.token_to_local_position_within_kv_block = torch.empty_like(self.token_to_input_ids)
 
+        # NOTE: Need to build this outside the UVM / TMS context to avoid IMA.
+        if self.is_hybrid_model:
+            self.mamba_metadata = MambaMetadata(
+                max_requests=self.max_requests, max_tokens=self.max_tokens
+            )
+
         # Allocate large non-graphed buffers.
         need_static_addr = (
             self.static_kv_memory_pointers
@@ -736,9 +719,8 @@ def initialize_all_tensors(self) -> None:
         self._allocate_memory_buffer()
         self._allocate_mamba_states()
 
-        # Reset attention and Mamba state.
-        self.reset_attention_state()
-        self.reset_mamba_state()
+        # Reset tensor-related metadata.
+        self.reset_metadata()
 
     def reinitialize_inference_state_buffers(self):
         """Restore large tensors (KV cache, Mamba states) after a suspend.
@@ -752,9 +734,8 @@ def reinitialize_inference_state_buffers(self):
         if self.kv_cache_management_mode == KVCacheManagementMode.PERSIST:
             return
 
-        if self.unified_memory_level != 0 or self._uses_torch_memory_saver:
-            if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE:
-                self.reset()
+        if self.unified_memory_level != 0 and self._uses_torch_memory_saver:
+            # Need to bring back the memory block before we reset it.
             if self._uses_torch_memory_saver:
                 tag = self.TMS_TAG
                 if torch.distributed.get_rank() == 0:
@@ -766,6 +747,8 @@ def reinitialize_inference_state_buffers(self):
                     logging.info(
                         "torch_memory_saver: resumed %s, after: %s", tag, device_memory_summary()
                     )
+            if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE:
+                self.reset_metadata()
             return
 
         if self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD:
@@ -1492,30 +1475,8 @@ def initialize_attention_state(
         else:
             self.moe_routing_metadata.disable_static_buffer_recording()
 
-    def reset(self) -> None:
-        """Reset entire context.
-
-        This method does:
-        - Reset active/paused request/token counts to zero.
-        - Reset available blocks to entire memory.
-        - Reset other tensors to zeros (unncessary, just or sanity checking).
-
-        This method is useful after cuda graph warmup iterations, where the
-        context's memory buffer is referenced by the cuda graph system and
-        cannot be deallocated.
-        """
-
-        # Reset request/token counts.
-        self.total_request_count = 0
-        self.active_token_count = 0
-        self.lifetime_prefill_token_count = 0
-        self.paused_request_count = 0
-        self.batch_dimensions = InferenceBatchDimensions(
-            token_count=0, prefill_req_count=0, decode_req_count=0
-        )
-        self.padded_active_token_count = 0
-        self.padded_active_request_count = 0
-        self.paused_tokens = None
+    def reset_tensors(self) -> None:
+        """Fill all GPU tensors with sentinel values."""
 
         # Reset request indexes.
         self.request_ids.fill_(-1)
@@ -1539,7 +1500,29 @@ def reset(self) -> None:
         self.token_to_block_idx.fill_(-1)
         self.token_to_local_position_within_kv_block.fill_(0)
 
-        # Reset available block count.
+    def reset_metadata(self) -> None:
+        """Reset all bookkeeping state: counters, block allocator, attention/mamba state.
+
+        This must be called after ``initialize_all_tensors()`` and after any
+        suspend/resume cycle to bring the context back to a clean state.
+        """
+
+        # Reset request/token counts.
+        self.total_request_count = 0
+        self.active_token_count = 0
+        self.lifetime_prefill_token_count = 0
+        self.paused_request_count = 0
+        self.batch_dimensions = InferenceBatchDimensions(
+            token_count=0, prefill_req_count=0, decode_req_count=0
+        )
+        self.padded_batch_dimensions = InferenceBatchDimensions(
+            token_count=0, prefill_req_count=0, decode_req_count=0
+        )
+        self.padded_active_token_count = 0
+        self.padded_active_request_count = 0
+        self.paused_tokens = None
+
+        # Reset attention, mamba, and block allocator state.
         self.reset_attention_state()
         self.reset_mamba_state()
         self.block_allocator.reset()
@@ -1557,6 +1540,21 @@ def reset(self) -> None:
             token_count=0, prefill_req_count=0, decode_req_count=0
         )
 
+    def reset(self) -> None:
+        """Reset entire context.
+
+        This method does:
+        - Fill all GPU tensors with sentinel values.
+        - Reset active/paused request/token counts to zero.
+        - Reset available blocks to entire memory.
+
+        This method is useful after cuda graph warmup iterations, where the
+        context's memory buffer is referenced by the cuda graph system and
+        cannot be deallocated.
+        """
+        self.reset_tensors()
+        self.reset_metadata()
+
     def current_input_and_position_ids(
         self, *, num_warmup_tokens: Optional[int] = None
     ) -> Tuple[Tensor, Tensor]: