106 changes: 52 additions & 54 deletions megatron/core/inference/contexts/dynamic_context.py
@@ -467,21 +467,6 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
self.params_dtype = model_config.params_dtype
self.max_sequence_length = inference_config.max_sequence_length

# Request and token counts.
self.total_request_count = 0
self.active_token_count = 0
self.lifetime_prefill_token_count = 0
self.paused_request_count = 0
self.batch_dimensions = InferenceBatchDimensions(
token_count=0, prefill_req_count=0, decode_req_count=0
)
self.padded_batch_dimensions = InferenceBatchDimensions(
token_count=0, prefill_req_count=0, decode_req_count=0
)
self.padded_active_token_count = 0
self.padded_active_request_count = 0
self.paused_tokens = None

# Block ids.
self.max_kv_block_count = math.ceil(self.max_sequence_length / self.block_size_tokens)

@@ -514,11 +499,8 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC

# Attention metadata initialization (tensors are now handled by MHAMetadata classes)

self.num_prefill_requests = 0
self.graph_attn_metadata = {}
self.non_graph_attn_metadata = {}
self.active_attn_metadata = None
self.is_creating_cuda_graphs = False

self.graph_attn_metadata["mha_metadata"] = GraphedMHAMetadata(
block_count_total=self.block_allocator.total_count,
@@ -564,10 +546,8 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
inference_config.cuda_graph_mixed_prefill_count, self.max_requests
)

self._using_cuda_graph_this_step = False
# Deal with chunked prefill
self.enable_chunked_prefill = inference_config.enable_chunked_prefill
self.chunked_prefill_request_id = -1

# FlashInfer.
if inference_config.use_flashinfer_fused_rope is True:
@@ -629,9 +609,6 @@ def _allocate_memory_buffer(self):
def _allocate_mamba_states(self):
"""Allocate Mamba states for hybrid models."""
if self.is_hybrid_model:
self.mamba_metadata = MambaMetadata(
max_requests=self.max_requests, max_tokens=self.max_tokens
)
self.mamba_conv_states = torch.empty(
(self.num_mamba_layers, self.max_requests) + self.mamba_conv_states_shape,
dtype=self.params_dtype,
@@ -715,6 +692,12 @@ def initialize_all_tensors(self) -> None:
self.token_to_position_in_request = torch.empty_like(self.token_to_input_ids)
self.token_to_local_position_within_kv_block = torch.empty_like(self.token_to_input_ids)

# NOTE: Need to build this outside the UVM / TMS context to avoid IMA.
if self.is_hybrid_model:
self.mamba_metadata = MambaMetadata(
max_requests=self.max_requests, max_tokens=self.max_tokens
)
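# (UVM = unified virtual memory, TMS = torch_memory_saver, IMA = illegal
# memory access. MambaMetadata holds small index tensors that presumably need
# stable device addresses, so they are built before the pageable UVM/TMS
# region rather than inside it, where a later suspend/resume could
# invalidate them.)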

# Allocate large non-graphed buffers.
need_static_addr = (
self.static_kv_memory_pointers
@@ -734,9 +717,8 @@ def initialize_all_tensors(self) -> None:
self._allocate_memory_buffer()
self._allocate_mamba_states()

# Reset attention and Mamba state.
self.reset_attention_state()
self.reset_mamba_state()
# Reset tensor-related metadata.
self.reset_metadata()

def reinitialize_inference_state_buffers(self):
"""Restore large tensors (KV cache, Mamba states) after a suspend.
@@ -750,11 +732,12 @@ def reinitialize_inference_state_buffers(self):
if self.kv_cache_management_mode == KVCacheManagementMode.PERSIST:
return

if self.unified_memory_level != 0 or self._uses_torch_memory_saver:
if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE:
self.reset()
if self.unified_memory_level != 0 and self._uses_torch_memory_saver:
# Need to bring back the memory block before we reset it.
if self._uses_torch_memory_saver:
torch_memory_saver.resume("inference_context")
if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE:
self.reset_metadata()
return

if self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD:
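The per-mode behavior above, condensed into a standalone sketch. This is not the verbatim method; the exact nesting of the resume/reset calls is only partially visible in this hunk, and the stand-in enum, the `ctx` parameter, and the import path are assumptions for illustration:

from enum import Enum, auto

from torch_memory_saver import torch_memory_saver  # import path assumed


class KVCacheManagementMode(Enum):  # stand-in for the real enum
    PERSIST = auto()    # buffers survive suspend; nothing to restore
    RECOMPUTE = auto()  # memory is paged back in; contents treated as stale
    OFFLOAD = auto()    # contents were copied off-device and are restored


def reinitialize(ctx) -> None:
    # Condensed view of reinitialize_inference_state_buffers() above.
    mode = ctx.kv_cache_management_mode
    if mode is KVCacheManagementMode.PERSIST:
        return
    if mode is KVCacheManagementMode.RECOMPUTE:
        if ctx._uses_torch_memory_saver:
            # Bring the saved memory block back before resetting it.
            torch_memory_saver.resume("inference_context")
        ctx.reset_metadata()  # rebuild bookkeeping over the stale cache
        return
    # OFFLOAD: copy the saved KV cache / Mamba states back to device (elided).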
@@ -1472,30 +1455,8 @@ def initialize_attention_state(
else:
self.moe_routing_metadata.disable_static_buffer_recording()

def reset(self) -> None:
"""Reset entire context.

This method does:
- Reset active/paused request/token counts to zero.
- Reset available blocks to entire memory.
- Reset other tensors to zeros (unnecessary, just for sanity checking).

This method is useful after cuda graph warmup iterations, where the
context's memory buffer is referenced by the cuda graph system and
cannot be deallocated.
"""

# Reset request/token counts.
self.total_request_count = 0
self.active_token_count = 0
self.lifetime_prefill_token_count = 0
self.paused_request_count = 0
self.batch_dimensions = InferenceBatchDimensions(
token_count=0, prefill_req_count=0, decode_req_count=0
)
self.padded_active_token_count = 0
self.padded_active_request_count = 0
self.paused_tokens = None
def reset_tensors(self) -> None:
"""Fill all GPU tensors with sentinel values."""

# Reset request indexes.
self.request_ids.fill_(-1)
@@ -1519,7 +1480,29 @@ def reset(self) -> None:
self.token_to_block_idx.fill_(-1)
self.token_to_local_position_within_kv_block.fill_(0)
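# (Sentinel convention, as far as this hunk shows: index-like tensors are
# filled with -1, meaning "no valid request/block", while position-like
# tensors are zeroed.)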

# Reset available block count.
def reset_metadata(self) -> None:
"""Reset all bookkeeping state: counters, block allocator, attention/mamba state.

This must be called after ``initialize_all_tensors()`` and after any
suspend/resume cycle to bring the context back to a clean state.
"""

# Reset request/token counts.
self.total_request_count = 0
self.active_token_count = 0
self.lifetime_prefill_token_count = 0
self.paused_request_count = 0
self.batch_dimensions = InferenceBatchDimensions(
token_count=0, prefill_req_count=0, decode_req_count=0
)
self.padded_batch_dimensions = InferenceBatchDimensions(
token_count=0, prefill_req_count=0, decode_req_count=0
)
self.padded_active_token_count = 0
self.padded_active_request_count = 0
self.paused_tokens = None

# Reset attention, mamba, and block allocator state.
self.reset_attention_state()
self.reset_mamba_state()
self.block_allocator.reset()
@@ -1537,6 +1520,21 @@ def reset(self) -> None:
token_count=0, prefill_req_count=0, decode_req_count=0
)

def reset(self) -> None:
"""Reset entire context.

This method does:
- Fill all GPU tensors with sentinel values.
- Reset active/paused request/token counts to zero.
- Reset available blocks to entire memory.

This method is useful after cuda graph warmup iterations, where the
context's memory buffer is referenced by the cuda graph system and
cannot be deallocated.
"""
self.reset_tensors()
self.reset_metadata()
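
Taken together, the refactor splits the old monolithic reset() into
reset_tensors() (GPU-side sentinel fills) and reset_metadata() (host-side
bookkeeping), with reset() composing both. A hypothetical usage timeline,
assuming `ctx` is a DynamicInferenceContext instance:

# Startup: allocate buffers; initialize_all_tensors() ends in reset_metadata().
ctx.initialize_all_tensors()

# After suspend/resume: restore buffers; in RECOMPUTE mode this re-runs
# reset_metadata().
ctx.reinitialize_inference_state_buffers()

# Full wipe, e.g. after cuda graph warmup, where the buffers are captured by
# the graphs and must stay allocated:
ctx.reset()  # reset_tensors() + reset_metadata()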

def current_input_and_position_ids(
self, *, num_warmup_tokens: Optional[int] = None
) -> Tuple[Tensor, Tensor]: