From c4f3f12d26dd1347094aac04e6397c0a1614679e Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Wed, 29 Oct 2025 22:55:55 +0000
Subject: [PATCH 1/4] init

Signed-off-by: Sage Moore
---
 vllm/v1/worker/gpu_model_runner.py   | 54 ++++++++++++++++++----------
 vllm/v1/worker/gpu_ubatch_wrapper.py | 11 ------
 2 files changed, 35 insertions(+), 30 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e350988456f1..1af00a4e597b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3886,7 +3886,7 @@ def _capture_cudagraphs(
             # cudagraph, a uniform decode batch, and the number of tokens
             # is above the threshold. Otherwise we just capture a non-ubatched
             # version of the graph
-            allow_microbatching = (
+            microbatching_enabled = (
                 self.parallel_config.enable_dbo
                 and cudagraph_runtime_mode == CUDAGraphMode.FULL
                 and uniform_decode
@@ -3897,32 +3897,48 @@ def _capture_cudagraphs(
                 )
             )

-            for _ in range(self.compilation_config.cudagraph_num_of_warmups):
-                # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
-                # But be careful, warm up with `NONE`is orthogonal to
-                # if we want to warm up attention or not. This is
-                # different from the case where `FULL` implies capture
-                # attention while `PIECEWISE` implies no attention.
-                force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
+            # When num_tokens is near the dbo_decode_token_threshold, different ranks
+            # may make different microbatching decisions (some above threshold, some
+            # below). Since all ranks must agree for DBO to work, they'll all fall back
+            # to non-DBO execution. To avoid running without cudagraphs in these mixed
+            # cases, we preemptively compile cudagraphs for both microbatching modes.
+            microbatching_modes = [microbatching_enabled]
+            compile_both_modes = (
+                num_tokens <= self.parallel_config.dbo_decode_token_threshold * 1.5
+            )
+            if microbatching_enabled and compile_both_modes:
+                microbatching_modes = [False, True]  # Compile both modes
+
+            for allow_microbatching in microbatching_modes:
+                for _ in range(self.compilation_config.cudagraph_num_of_warmups):
+                    # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
+                    # But be careful, warm up with `NONE`is orthogonal to
+                    # if we want to warm up attention or not. This is
+                    # different from the case where `FULL` implies capture
+                    # attention while `PIECEWISE` implies no attention.
+                    force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
+                    self._dummy_run(
+                        num_tokens,
+                        cudagraph_runtime_mode=CUDAGraphMode.NONE,
+                        force_attention=force_attention,
+                        uniform_decode=uniform_decode,
+                        allow_microbatching=allow_microbatching,
+                        skip_eplb=True,
+                        remove_lora=False,
+                        activate_lora=activate_lora,
+                    )
+                logger.info(
+                    "MAKING GRAPH FOR SHAPE %s %s", allow_microbatching, num_tokens
+                )
                 self._dummy_run(
                     num_tokens,
-                    cudagraph_runtime_mode=CUDAGraphMode.NONE,
-                    force_attention=force_attention,
+                    cudagraph_runtime_mode=cudagraph_runtime_mode,
                     uniform_decode=uniform_decode,
                     allow_microbatching=allow_microbatching,
                     skip_eplb=True,
                     remove_lora=False,
                     activate_lora=activate_lora,
                 )
-            self._dummy_run(
-                num_tokens,
-                cudagraph_runtime_mode=cudagraph_runtime_mode,
-                uniform_decode=uniform_decode,
-                allow_microbatching=allow_microbatching,
-                skip_eplb=True,
-                remove_lora=False,
-                activate_lora=activate_lora,
-            )
         self.maybe_remove_all_loras(self.lora_config)

     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py
index 9de123263755..f8c410eb3619 100644
--- a/vllm/v1/worker/gpu_ubatch_wrapper.py
+++ b/vllm/v1/worker/gpu_ubatch_wrapper.py
@@ -380,17 +380,6 @@ def __call__(self, *args, **kwargs):

         # If there's no ubatching, just run the runnable object
         if ubatch_slices is None:
-            # This is to account for the case where ubatching was aborted.
-            # When we capture full graphs we only capture one graph per shape,
-            # meaning that if we have a ubatched cudagraph for the current
-            # num_tokens, we don't have a non-ubatched one. Without this
-            # check, the cudagraph wrapper will try to capture a cudagraph
-            # for this shape during a normal run.
-            if cudagraph_runtime_mode is CUDAGraphMode.FULL:
-                assert batch_descriptor is not None
-                if batch_descriptor.num_tokens in self.cudagraphs:
-                    cudagraph_runtime_mode = CUDAGraphMode.NONE
-
             if cudagraph_runtime_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE):
                 return self.runnable(*args, **kwargs)
             else:

From e2a97301a3299e7711c9347a0190456bf7142e32 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Wed, 29 Oct 2025 22:56:21 +0000
Subject: [PATCH 2/4] remove log

Signed-off-by: Sage Moore
---
 vllm/v1/worker/gpu_model_runner.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 1af00a4e597b..8cf34ad693aa 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3927,9 +3927,6 @@ def _capture_cudagraphs(
                         remove_lora=False,
                         activate_lora=activate_lora,
                     )
-                logger.info(
-                    "MAKING GRAPH FOR SHAPE %s %s", allow_microbatching, num_tokens
-                )
                 self._dummy_run(
                     num_tokens,
                     cudagraph_runtime_mode=cudagraph_runtime_mode,

From 43e9551574c5aa670f362ec9260460e067223fe2 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Wed, 29 Oct 2025 23:21:42 +0000
Subject: [PATCH 3/4] refactoring

Signed-off-by: Sage Moore
---
 vllm/v1/worker/gpu_model_runner.py   |  5 +++--
 vllm/v1/worker/gpu_ubatch_wrapper.py | 21 +++++++++++++++++++++
 vllm/v1/worker/ubatch_utils.py       |  6 ++++++
 3 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 8cf34ad693aa..545118698b1b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -133,6 +133,7 @@
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlice,
     UBatchSlices,
+    check_cudagraph_threshold,
     check_ubatch_thresholds,
 )
 from vllm.v1.worker.utils import is_residual_scattered_for_sp
@@ -3903,8 +3904,8 @@ def _capture_cudagraphs(
             # to non-DBO execution. To avoid running without cudagraphs in these mixed
             # cases, we preemptively compile cudagraphs for both microbatching modes.
             microbatching_modes = [microbatching_enabled]
-            compile_both_modes = (
-                num_tokens <= self.parallel_config.dbo_decode_token_threshold * 1.5
+            compile_both_modes = check_cudagraph_threshold(
+                self.parallel_config, num_tokens, uniform_decode
             )
             if microbatching_enabled and compile_both_modes:
                 microbatching_modes = [False, True]  # Compile both modes
diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py
index f8c410eb3619..e2d372da4124 100644
--- a/vllm/v1/worker/gpu_ubatch_wrapper.py
+++ b/vllm/v1/worker/gpu_ubatch_wrapper.py
@@ -23,6 +23,7 @@
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import has_deep_gemm
+from vllm.v1.worker.ubatch_utils import check_cudagraph_threshold
 from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts

 logger = init_logger(__name__)
@@ -380,6 +381,26 @@ def __call__(self, *args, **kwargs):

         # If there's no ubatching, just run the runnable object
         if ubatch_slices is None:
+            # This is to account for the case where ubatching was aborted.
+            # When we capture full graphs we only capture one graph per shape,
+            # meaning that if we have a ubatched cudagraph for the current
+            # num_tokens, we don't have a non-ubatched one. Without this
+            # check, the cudagraph wrapper will try to capture a cudagraph
+            # for this shape during a normal run.
+            if cudagraph_runtime_mode is CUDAGraphMode.FULL:
+                assert batch_descriptor is not None
+                num_tokens = batch_descriptor.num_tokens
+                uniform_decode = (
+                    cudagraph_runtime_mode.decode_mode() == CUDAGraphMode.FULL
+                    and cudagraph_runtime_mode.separate_routine()
+                )
+                # Check if the model runner made a non-dbo cudagraph for this shape
+                cudagraph_exists = check_cudagraph_threshold(
+                    self.vllm_config.parallel_config, num_tokens, uniform_decode
+                )
+                if num_tokens in self.cudagraphs and not cudagraph_exists:
+                    cudagraph_runtime_mode = CUDAGraphMode.NONE
+
             if cudagraph_runtime_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE):
                 return self.runnable(*args, **kwargs)
             else:
diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py
index 33a1921d2d98..b97510984975 100644
--- a/vllm/v1/worker/ubatch_utils.py
+++ b/vllm/v1/worker/ubatch_utils.py
@@ -42,6 +42,12 @@ def check_ubatch_thresholds(
     return num_tokens >= config.dbo_prefill_token_threshold


+def check_cudagraph_threshold(
+    config: ParallelConfig, num_tokens: int, uniform_decode: bool
+):
+    return uniform_decode and num_tokens > config.dbo_decode_token_threshold * 1.5
+
+
 def create_ubatch_slices(
     num_scheduled_tokens: np.ndarray, split_point: int
 ) -> UBatchSlices:

From ea8936319ba281022739d24e9956b0447902497d Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Thu, 30 Oct 2025 18:39:36 +0000
Subject: [PATCH 4/4] misc fixes

Signed-off-by: Sage Moore
---
 vllm/v1/worker/gpu_ubatch_wrapper.py | 5 +----
 vllm/v1/worker/ubatch_utils.py       | 2 +-
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py
index e2d372da4124..7614c21eaf3e 100644
--- a/vllm/v1/worker/gpu_ubatch_wrapper.py
+++ b/vllm/v1/worker/gpu_ubatch_wrapper.py
@@ -390,10 +390,7 @@ def __call__(self, *args, **kwargs):
             if cudagraph_runtime_mode is CUDAGraphMode.FULL:
                 assert batch_descriptor is not None
                 num_tokens = batch_descriptor.num_tokens
-                uniform_decode = (
-                    cudagraph_runtime_mode.decode_mode() == CUDAGraphMode.FULL
-                    and cudagraph_runtime_mode.separate_routine()
-                )
+                uniform_decode = batch_descriptor.uniform_decode
                 # Check if the model runner made a non-dbo cudagraph for this shape
                 cudagraph_exists = check_cudagraph_threshold(
                     self.vllm_config.parallel_config, num_tokens, uniform_decode
diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py
index b97510984975..fce91e532c25 100644
--- a/vllm/v1/worker/ubatch_utils.py
+++ b/vllm/v1/worker/ubatch_utils.py
@@ -45,7 +45,7 @@ def check_ubatch_thresholds(
 def check_cudagraph_threshold(
     config: ParallelConfig, num_tokens: int, uniform_decode: bool
 ):
-    return uniform_decode and num_tokens > config.dbo_decode_token_threshold * 1.5
+    return uniform_decode and num_tokens <= config.dbo_decode_token_threshold * 1.5
 
 
 def create_ubatch_slices(
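
The sketch below is not part of the patch series. It is a minimal, self-contained illustration of the capture-side behavior the series converges on: for uniform-decode shapes near the DBO decode threshold, the model runner captures both a microbatched and a non-microbatched cudagraph for the same shape, so the ubatch wrapper can fall back to a non-DBO graph when microbatching is aborted at runtime. SimpleParallelConfig and capture_plan are hypothetical stand-ins rather than vLLM APIs, and the simple >= comparison stands in for check_ubatch_thresholds; only check_cudagraph_threshold mirrors the final form from PATCH 4/4.

# Standalone sketch (assumptions noted above); not vLLM code.
from dataclasses import dataclass


@dataclass
class SimpleParallelConfig:
    # Hypothetical stand-in for vLLM's ParallelConfig; only the fields
    # used by this sketch are modeled.
    enable_dbo: bool = True
    dbo_decode_token_threshold: int = 32


def check_cudagraph_threshold(
    config: SimpleParallelConfig, num_tokens: int, uniform_decode: bool
) -> bool:
    # Mirrors the final predicate from PATCH 4/4: a non-DBO graph is only
    # guaranteed for uniform-decode shapes at or below 1.5x the threshold.
    return uniform_decode and num_tokens <= config.dbo_decode_token_threshold * 1.5


def capture_plan(
    config: SimpleParallelConfig, num_tokens: int, uniform_decode: bool
) -> list[bool]:
    # Simplified version of the _capture_cudagraphs decision: which
    # microbatching modes to capture for a given shape.
    microbatching_enabled = (
        config.enable_dbo
        and uniform_decode
        # Simplified decode-threshold check standing in for check_ubatch_thresholds.
        and num_tokens >= config.dbo_decode_token_threshold
    )
    modes = [microbatching_enabled]
    if microbatching_enabled and check_cudagraph_threshold(
        config, num_tokens, uniform_decode
    ):
        modes = [False, True]  # capture both a non-DBO and a DBO graph
    return modes


if __name__ == "__main__":
    cfg = SimpleParallelConfig()
    for n in (16, 32, 40, 64):
        print(n, capture_plan(cfg, n, uniform_decode=True))
    # 16 -> [False]        below threshold: non-DBO graph only
    # 32 -> [False, True]  at threshold: both variants captured
    # 40 -> [False, True]  still within 1.5x: both variants captured
    # 64 -> [True]         well above 1.5x: DBO-only graph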