diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 91015ad4379c..03b14ef80a5d 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -136,6 +136,7 @@
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlice,
     UBatchSlices,
+    check_cudagraph_threshold,
     check_ubatch_thresholds,
 )
 from vllm.v1.worker.utils import is_residual_scattered_for_sp
@@ -3980,7 +3981,7 @@ def _capture_cudagraphs(
             # cudagraph, a uniform decode batch, and the number of tokens
             # is above the threshold. Otherwise we just capture a non-ubatched
             # version of the graph
-            allow_microbatching = (
+            microbatching_enabled = (
                 self.parallel_config.enable_dbo
                 and cudagraph_runtime_mode == CUDAGraphMode.FULL
                 and uniform_decode
@@ -3991,32 +3992,45 @@ def _capture_cudagraphs(
                 )
             )
 
-            for _ in range(self.compilation_config.cudagraph_num_of_warmups):
-                # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
-                # But be careful, warm up with `NONE`is orthogonal to
-                # if we want to warm up attention or not. This is
-                # different from the case where `FULL` implies capture
-                # attention while `PIECEWISE` implies no attention.
-                force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
+            # When num_tokens is near the dbo_decode_token_threshold, different ranks
+            # may make different microbatching decisions (some above threshold, some
+            # below). Since all ranks must agree for DBO to work, they'll all fall back
+            # to non-DBO execution. To avoid running without cudagraphs in these mixed
+            # cases, we preemptively compile cudagraphs for both microbatching modes.
+            microbatching_modes = [microbatching_enabled]
+            compile_both_modes = check_cudagraph_threshold(
+                self.parallel_config, num_tokens, uniform_decode
+            )
+            if microbatching_enabled and compile_both_modes:
+                microbatching_modes = [False, True]  # Compile both modes
+
+            for allow_microbatching in microbatching_modes:
+                for _ in range(self.compilation_config.cudagraph_num_of_warmups):
+                    # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
+                    # But be careful, warm up with `NONE` is orthogonal to
+                    # if we want to warm up attention or not. This is
+                    # different from the case where `FULL` implies capture
+                    # attention while `PIECEWISE` implies no attention.
+                    force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
+                    self._dummy_run(
+                        num_tokens,
+                        cudagraph_runtime_mode=CUDAGraphMode.NONE,
+                        force_attention=force_attention,
+                        uniform_decode=uniform_decode,
+                        allow_microbatching=allow_microbatching,
+                        skip_eplb=True,
+                        remove_lora=False,
+                        activate_lora=activate_lora,
+                    )
                 self._dummy_run(
                     num_tokens,
-                    cudagraph_runtime_mode=CUDAGraphMode.NONE,
-                    force_attention=force_attention,
+                    cudagraph_runtime_mode=cudagraph_runtime_mode,
                     uniform_decode=uniform_decode,
                     allow_microbatching=allow_microbatching,
                     skip_eplb=True,
                     remove_lora=False,
                     activate_lora=activate_lora,
                 )
-            self._dummy_run(
-                num_tokens,
-                cudagraph_runtime_mode=cudagraph_runtime_mode,
-                uniform_decode=uniform_decode,
-                allow_microbatching=allow_microbatching,
-                skip_eplb=True,
-                remove_lora=False,
-                activate_lora=activate_lora,
-            )
         self.maybe_remove_all_loras(self.lora_config)
 
     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py
index 9de123263755..7614c21eaf3e 100644
--- a/vllm/v1/worker/gpu_ubatch_wrapper.py
+++ b/vllm/v1/worker/gpu_ubatch_wrapper.py
@@ -23,6 +23,7 @@
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import has_deep_gemm
+from vllm.v1.worker.ubatch_utils import check_cudagraph_threshold
 from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
 
 logger = init_logger(__name__)
@@ -388,7 +389,13 @@ def __call__(self, *args, **kwargs):
             # for this shape during a normal run.
             if cudagraph_runtime_mode is CUDAGraphMode.FULL:
                 assert batch_descriptor is not None
-                if batch_descriptor.num_tokens in self.cudagraphs:
+                num_tokens = batch_descriptor.num_tokens
+                uniform_decode = batch_descriptor.uniform_decode
+                # Check if the model runner made a non-dbo cudagraph for this shape
+                has_non_dbo_cudagraph = check_cudagraph_threshold(
+                    self.vllm_config.parallel_config, num_tokens, uniform_decode
+                )
+                if num_tokens in self.cudagraphs and not has_non_dbo_cudagraph:
                     cudagraph_runtime_mode = CUDAGraphMode.NONE
 
             if cudagraph_runtime_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE):
diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py
index 33a1921d2d98..fce91e532c25 100644
--- a/vllm/v1/worker/ubatch_utils.py
+++ b/vllm/v1/worker/ubatch_utils.py
@@ -42,6 +42,18 @@ def check_ubatch_thresholds(
         return num_tokens >= config.dbo_prefill_token_threshold
 
 
+def check_cudagraph_threshold(
+    config: ParallelConfig, num_tokens: int, uniform_decode: bool
+) -> bool:
+    """Return whether a non-DBO cudagraph is also captured for this shape.
+
+    True for uniform-decode batches of at most 1.5x
+    ``dbo_decode_token_threshold`` tokens, the range in which ranks may
+    disagree on microbatching and fall back to non-DBO execution.
+    """
+    return uniform_decode and num_tokens <= config.dbo_decode_token_threshold * 1.5
+
+
 def create_ubatch_slices(
     num_scheduled_tokens: np.ndarray, split_point: int
 ) -> UBatchSlices: