
Conversation


@jhaotingc jhaotingc commented Dec 24, 2025

Description

Previously, the entire draft forward step in modeling_speculative.py was captured inside the CUDA graph. This is fast, but not flexible enough for additional sampling, CPU-side operations, etc.

This PR introduces a flag to move the draft loop out of CUDA graph capture.

After this PR, with enable_cuda_graph_for_draft_model=False, the draft loop is no longer captured in the graph, but its kernels are still queued on the stream.
(screenshot: kernel timeline after this PR, with the draft loop outside CUDA graph capture)

Before this PR, the draft loop is included in CUDA graph capture.
(screenshot: kernel timeline before this PR)

The new behavior is not enabled by default, to avoid a performance regression.
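
For reference, a minimal sketch of how the new flag might be set through the LLM API (not part of this PR). Only enable_cuda_graph_for_draft_model and eagle3_one_model come from this change; the other names (LLM, speculative_config, max_draft_len, speculative_model_dir) are assumptions about the existing API.

# Hypothetical usage sketch; names other than the new flag are assumptions.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig

spec_config = EagleDecodingConfig(
    max_draft_len=3,                                    # assumed existing field
    speculative_model_dir="<eagle3-draft-checkpoint>",  # assumed existing field
    eagle3_one_model=True,       # the flag below may only be False in one-model mode
    enable_cuda_graph_for_draft_model=False,  # new flag from this PR; defaults to True
)

llm = LLM(model="<llama-3.1-8b-instruct-checkpoint>",
          speculative_config=spec_config)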

Guided decoder functionality is still being fixed.

Guided decoder performance is also still being fixed.
(screenshot: guided decoder performance)

Summary by CodeRabbit

  • New Features

    • Added CUDA graph support configuration option for Eagle3 speculative decoding, enabling GPU acceleration of draft model inference. New parameter enable_cuda_graph_for_draft_model defaults to True, with validation ensuring compatibility with Eagle3 one-model configuration.
  • Tests

    • Expanded Eagle3 test coverage with new parameter combinations to validate CUDA graph behavior across various decoding configurations.


Test Coverage

Tried several combinations; all passed the accuracy tests. Acceptance length was hand-checked.

LLM_MODELS_ROOT=/trt_llm/data/llm-models pytest -s -vv tests/integration/defs/accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_eagle3[enable_cuda_graph_for_draft_model=False-sampler_async_worker=True-eagle3_one_model=True-overlap_scheduler=True]

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@jhaotingc jhaotingc requested review from a team as code owners December 24, 2025 01:57
@jhaotingc
Collaborator Author

/bot run --disable-fail-fast

@jhaotingc jhaotingc requested a review from syuoni December 24, 2025 02:07

coderabbitai bot commented Dec 24, 2025

📝 Walkthrough

Walkthrough

This pull request introduces a new configuration flag enable_cuda_graph_for_draft_model to control draft model execution within CUDA graph capture contexts. When disabled during graph capture, the draft path is skipped at capture time and executed separately afterward, ensuring backward compatibility while enabling flexible CUDA graph handling for speculative decoding.
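
As a rough illustration of that flow (heavily simplified; the real methods in modeling_speculative.py and model_engine.py take attn_metadata, spec_metadata and more arguments, and run_target here is a hypothetical stand-in):

import torch

def target_forward(model, inputs, enable_cuda_graph_for_draft_model):
    hidden_states = model.run_target(inputs)  # hypothetical stand-in for the target-model forward
    if torch.cuda.is_current_stream_capturing() and not enable_cuda_graph_for_draft_model:
        # Recording a CUDA graph while the draft loop is excluded:
        # return early, the executor runs the draft loop after replay.
        return hidden_states
    # Default path: the draft loop runs here (and is recorded during capture).
    return model.forward_draft(hidden_states, inputs)

def executor_decode_step(engine, key, inputs):
    outputs = engine.cuda_graph_runner.replay(key, inputs)
    if not engine.enable_cuda_graph_for_draft_model:
        # The graph only produced hidden states; run the draft loop eagerly on them.
        outputs = engine.model.forward_draft(outputs, inputs)
    return outputs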

Changes

Cohort / File(s) | Summary
--- | ---
Core Speculative Model Implementation: tensorrt_llm/_torch/models/modeling_speculative.py | Added the enable_cuda_graph_for_draft_model flag to spec config handling; introduces a runtime check using is_current_stream_capturing() to conditionally skip draft execution during graph capture; extracts the draft-path logic into a new forward_draft() helper method
Executor Integration: tensorrt_llm/_torch/pyexecutor/model_engine.py | Stores enable_cuda_graph_for_draft_model from spec_config on PyTorchModelEngine; adds a post-CUDA-graph-replay branch that executes forward_draft() when the flag is False, ensuring the draft model refines outputs using the original inputs
Configuration API: tensorrt_llm/llmapi/llm_args.py | Added public field enable_cuda_graph_for_draft_model: Optional[bool] = True to EagleDecodingConfig; added a validation guard in __init__ raising ValueError when the flag is False while eagle3_one_model is False
Test Configuration Metadata: tests/integration/defs/.test_durations | Updated two test identifiers to include the enable_cuda_graph_for_draft_model parameter in Eagle3 test keys; durations remain unchanged
Test Implementation & Parameterization: tests/integration/defs/accuracy/test_llm_api_pytorch.py, tests/integration/test_lists/qa/llm_digits_func.txt, tests/integration/test_lists/qa/llm_function_core.txt, tests/integration/test_lists/qa/llm_function_core_sanity.txt, tests/integration/test_lists/qa/llm_function_l20.txt, tests/integration/test_lists/qa/llm_function_rtx6k.txt | Updated test_eagle3() to accept the enable_cuda_graph_for_draft_model parameter and conditionally skip when it is False with eagle3_one_model False; replaced three old test variants with four new parameterized variants across multiple test configuration lists
Test Database Configuration: tests/integration/test_lists/test-db/l0_h100.yml | Expanded the Eagle3 test matrix by replacing three variants with four new variants that include enable_cuda_graph_for_draft_model flag combinations

Sequence Diagram(s)

sequenceDiagram
    participant Client as Client/Main Forward
    participant Model as Speculative Model
    participant Draft as Draft Model
    participant Executor as PyTorch Executor

    rect rgb(200, 220, 255)
    Note over Client,Executor: Standard Path: enable_cuda_graph_for_draft_model = True
    Client->>Executor: Forward with CUDA graph
    Executor->>Model: forward() with graph capture active
    Model->>Draft: forward_draft() [captured in graph]
    Draft-->>Model: draft logits
    Model-->>Executor: spec results
    Executor-->>Client: output
    end

    rect rgb(255, 220, 200)
    Note over Client,Executor: Alternate Path: enable_cuda_graph_for_draft_model = False
    Client->>Executor: Forward with CUDA graph
    Executor->>Model: forward() with graph capture active
    Model-->>Executor: hidden_states [early return, skip draft]
    Executor->>Model: forward_draft() [post-graph-replay]
    Model->>Draft: Compute draft logits with original inputs
    Draft-->>Model: draft logits
    Model-->>Executor: refined outputs
    Executor-->>Client: output
    end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

Suggested labels

Model customization

Suggested reviewers

  • mikeiovine
  • syuoni
  • ziyixiong-nv
  • lfr-0531

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)

Check name | Status | Explanation | Resolution
--- | --- | --- | ---
Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00%, which is below the required threshold of 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage.
Description check | ❓ Inconclusive | The PR description explains the feature but lacks clear structure and is missing required template sections such as the PR title format and comprehensive test coverage details. | Add a PR title following the template format [ticket][type] Summary, expand the Test Coverage section with specific test cases, and ensure all PR Checklist items are properly addressed with details.

✅ Passed checks (1 passed)

Check name | Status | Explanation
--- | --- | ---
Title check | ✅ Passed | The title clearly describes the main feature: introducing support for running the eagle3 one-model draft loop outside CUDA graph capture.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (2)
tensorrt_llm/llmapi/llm_args.py (1)

917-921: Use is False for explicit boolean checks with Optional[bool].

The static analysis flags == False comparisons. Since enable_cuda_graph_for_draft_model is Optional[bool], using is False is more explicit and avoids potential issues with truthiness comparisons:

🔎 Suggested fix
-        if self.enable_cuda_graph_for_draft_model == False and self.eagle3_one_model == False:
+        if self.enable_cuda_graph_for_draft_model is False and self.eagle3_one_model is False:
             raise ValueError(
                 "enable_cuda_graph_for_draft_model can be false only when eagle3_one_model is True"
             )
tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)

269-275: LGTM with optional style improvement.

The new test parameter correctly covers both enabling and disabling CUDA graph capture for the draft model. The skip condition appropriately prevents invalid test combinations where the flag is disabled in two-model mode.

Optional: Address linter suggestion for more Pythonic comparison
-        if enable_cuda_graph_for_draft_model == False and eagle3_one_model == False:
+        if not enable_cuda_graph_for_draft_model and not eagle3_one_model:
             pytest.skip(
                 "enable_cuda_graph_for_draft_model can be false only when eagle3_one_model is True"
             )
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 92d90fa and 9ca412a.

📒 Files selected for processing (11)
  • tensorrt_llm/_torch/models/modeling_speculative.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/llmapi/llm_args.py
  • tests/integration/defs/.test_durations
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
  • tests/integration/test_lists/qa/llm_digits_func.txt
  • tests/integration/test_lists/qa/llm_function_core.txt
  • tests/integration/test_lists/qa/llm_function_core_sanity.txt
  • tests/integration/test_lists/qa/llm_function_l20.txt
  • tests/integration/test_lists/qa/llm_function_rtx6k.txt
  • tests/integration/test_lists/test-db/l0_h100.yml
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces. Do not use tabs
Always maintain the namespace when importing in Python, even if only one class or function from a module is used
Python files should use snake_case naming: some_file.py
Python classes should use PascalCase naming: class SomeClass
Python functions and methods should use snake_case naming: def my_awesome_function():
Python local variables should use snake_case naming: my_variable = ...
Python variable names that start with a number should be prefixed with 'k': k_99th_percentile = ...
Python global variables should use upper snake_case with prefix 'G': G_MY_GLOBAL = ...
Python constants should use upper snake_case naming: MY_CONSTANT = ...
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Python comments should be reserved for code within a function, or interfaces that are local to a file
Use Google style docstrings in Python for classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with type and description
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except to the smallest set of errors possible
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible, using the else block for logic

Files:

  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
  • tensorrt_llm/_torch/models/modeling_speculative.py
**/*.{cpp,h,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification

Files:

  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
  • tensorrt_llm/_torch/models/modeling_speculative.py
🧠 Learnings (12)
📓 Common learnings
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
Learnt from: nzmora-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 9163
File: tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py:107-113
Timestamp: 2025-11-14T11:22:03.729Z
Learning: In TensorRT-LLM AutoDeploy custom ops, when adding hardware capability checks to select between kernel implementations (e.g., cuBLAS vs. CUDA kernel), use descriptive variable names that identify the specific GPU architectures or families being targeted (e.g., `is_blackwell_geforce_or_ada`) rather than generic names like `enable_cuda_core`. This makes it clear that the code is selecting an implementation path based on hardware capabilities, not enabling/disabling hardware features.
📚 Learning: 2025-09-09T09:40:45.658Z
Learnt from: fredricz-20070104
Repo: NVIDIA/TensorRT-LLM PR: 7645
File: tests/integration/test_lists/qa/llm_function_core.txt:648-648
Timestamp: 2025-09-09T09:40:45.658Z
Learning: In TensorRT-LLM test lists, it's common and intentional for the same test to appear in multiple test list files when they serve different purposes (e.g., llm_function_core.txt for comprehensive core functionality testing and llm_function_core_sanity.txt for quick sanity checks). This duplication allows tests to be run in different testing contexts.

Applied to files:

  • tests/integration/test_lists/qa/llm_function_core_sanity.txt
  • tests/integration/defs/.test_durations
  • tests/integration/test_lists/test-db/l0_h100.yml
  • tests/integration/test_lists/qa/llm_function_core.txt
  • tests/integration/test_lists/qa/llm_function_l20.txt
  • tests/integration/test_lists/qa/llm_function_rtx6k.txt
  • tests/integration/test_lists/qa/llm_digits_func.txt
📚 Learning: 2025-08-26T09:49:04.956Z
Learnt from: pengbowang-nv
Repo: NVIDIA/TensorRT-LLM PR: 7192
File: tests/integration/test_lists/test-db/l0_dgx_b200.yml:56-72
Timestamp: 2025-08-26T09:49:04.956Z
Learning: In TensorRT-LLM test configuration files, the test scheduling system handles wildcard matching with special rules that prevent duplicate test execution even when the same tests appear in multiple yaml files with overlapping GPU wildcards (e.g., "*b200*" and "*gb200*").

Applied to files:

  • tests/integration/test_lists/qa/llm_function_core_sanity.txt
  • tests/integration/test_lists/test-db/l0_h100.yml
  • tests/integration/test_lists/qa/llm_function_core.txt
  • tests/integration/test_lists/qa/llm_function_l20.txt
  • tests/integration/test_lists/qa/llm_function_rtx6k.txt
  • tests/integration/test_lists/qa/llm_digits_func.txt
📚 Learning: 2025-09-17T02:48:52.732Z
Learnt from: tongyuantongyu
Repo: NVIDIA/TensorRT-LLM PR: 7781
File: tests/integration/test_lists/waives.txt:313-313
Timestamp: 2025-09-17T02:48:52.732Z
Learning: In TensorRT-LLM, `tests/integration/test_lists/waives.txt` is specifically for waiving/skipping tests, while other test list files like those in `test-db/` and `qa/` directories are for different test execution contexts (pre-merge, post-merge, QA tests). The same test appearing in both waives.txt and execution list files is intentional - the test is part of test suites but will be skipped due to the waiver.

Applied to files:

  • tests/integration/test_lists/qa/llm_function_core_sanity.txt
  • tests/integration/test_lists/test-db/l0_h100.yml
  • tests/integration/test_lists/qa/llm_function_core.txt
  • tests/integration/test_lists/qa/llm_function_rtx6k.txt
📚 Learning: 2025-11-27T09:23:18.742Z
Learnt from: fredricz-20070104
Repo: NVIDIA/TensorRT-LLM PR: 9511
File: tests/integration/defs/examples/serve/test_serve.py:136-186
Timestamp: 2025-11-27T09:23:18.742Z
Learning: In TensorRT-LLM testing, when adding test cases based on RCCA commands, the command format should be copied exactly as it appears in the RCCA case, even if it differs from existing tests. For example, some RCCA commands for trtllm-serve may omit the "serve" subcommand while others include it.

Applied to files:

  • tests/integration/test_lists/qa/llm_function_core_sanity.txt
  • tests/integration/test_lists/qa/llm_function_rtx6k.txt
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/integration/test_lists/qa/llm_function_core_sanity.txt
  • tests/integration/defs/.test_durations
  • tests/integration/test_lists/test-db/l0_h100.yml
  • tests/integration/test_lists/qa/llm_function_core.txt
  • tests/integration/test_lists/qa/llm_function_l20.txt
  • tests/integration/test_lists/qa/llm_function_rtx6k.txt
  • tests/integration/test_lists/qa/llm_digits_func.txt
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
Repo: NVIDIA/TensorRT-LLM PR: 6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tests/integration/test_lists/test-db/l0_h100.yml
  • tests/integration/test_lists/qa/llm_function_rtx6k.txt
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which can contain default `cuda_graph_config` values, so `llm_args` may already have this config before the extra options processing.

Applied to files:

  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM's bench configuration, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which is a Dict[str, Any] that can contain default values including `cuda_graph_config`, making the fallback `llm_args["cuda_graph_config"]` safe to use.

Applied to files:

  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tests/integration/defs/accuracy/test_llm_api_pytorch.py
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/model_engine.py
📚 Learning: 2025-11-14T11:22:03.729Z
Learnt from: nzmora-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 9163
File: tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py:107-113
Timestamp: 2025-11-14T11:22:03.729Z
Learning: In TensorRT-LLM AutoDeploy custom ops, when adding hardware capability checks to select between kernel implementations (e.g., cuBLAS vs. CUDA kernel), use descriptive variable names that identify the specific GPU architectures or families being targeted (e.g., `is_blackwell_geforce_or_ada`) rather than generic names like `enable_cuda_core`. This makes it clear that the code is selecting an implementation path based on hardware capabilities, not enabling/disabling hardware features.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/model_engine.py
📚 Learning: 2025-12-12T03:27:08.565Z
Learnt from: tongyuantongyu
Repo: NVIDIA/TensorRT-LLM PR: 9655
File: tensorrt_llm/_torch/pyexecutor/sampler.py:3031-3031
Timestamp: 2025-12-12T03:27:08.565Z
Learning: In files under tensorrt_llm/_torch/pyexecutor, avoid accessing torch.Tensor objects inside for-loops when iterating over requests. Convert batched tensors to Python lists beforehand using tensor.tolist(), and then iterate over those lists. This improves performance by reducing tensor-bound operations inside hot loops. Apply this pattern to similar code paths that process batches to access simple Python data structures (lists) inside loops.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/model_engine.py
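
An illustrative sketch of the tolist()-before-the-loop pattern described in this learning (the names here are made up for the example):

import torch

def collect_seq_lens(seq_lens: torch.Tensor, requests) -> dict:
    # Convert the batched tensor to a Python list once, outside the per-request loop...
    seq_lens_list = seq_lens.tolist()
    # ...then iterate over plain Python values instead of indexing the tensor per request.
    return {req.request_id: seq_lens_list[i] for i, req in enumerate(requests)}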
🧬 Code graph analysis (2)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
tensorrt_llm/_torch/models/modeling_speculative.py (1)
  • forward_draft (846-872)
tensorrt_llm/_torch/models/modeling_speculative.py (3)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
  • forward (77-84)
  • forward (3150-3281)
tensorrt_llm/_torch/speculative/eagle3.py (1)
  • forward (373-504)
tensorrt_llm/_torch/speculative/mtp.py (2)
  • forward (362-559)
  • forward (1200-1388)
🪛 Ruff (0.14.10)
tensorrt_llm/llmapi/llm_args.py

917-917: Avoid equality comparisons to False; use not self.enable_cuda_graph_for_draft_model: for false checks

Replace with not self.enable_cuda_graph_for_draft_model

(E712)


917-917: Avoid equality comparisons to False; use not self.eagle3_one_model: for false checks

Replace with not self.eagle3_one_model

(E712)


918-920: Avoid specifying long messages outside the exception class

(TRY003)

tests/integration/defs/accuracy/test_llm_api_pytorch.py

272-272: Avoid equality comparisons to False; use not enable_cuda_graph_for_draft_model: for false checks

Replace with not enable_cuda_graph_for_draft_model

(E712)


272-272: Avoid equality comparisons to False; use not eagle3_one_model: for false checks

Replace with not eagle3_one_model

(E712)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (10)
tests/integration/test_lists/qa/llm_function_core.txt (1)

393-393: Test configuration may conflict with validation logic.

Line 393 specifies enable_cuda_graph_for_draft_model=False with eagle3_one_model=False. However, the new validation in EagleDecodingConfig.__init__ (llm_args.py lines 917-920) explicitly raises a ValueError when this combination is used:

"enable_cuda_graph_for_draft_model can be false only when eagle3_one_model is True"

Please verify whether this test is intentionally testing the error path (negative test) or if the parameters need to be corrected.

tensorrt_llm/llmapi/llm_args.py (1)

861-862: LGTM!

The new enable_cuda_graph_for_draft_model field is appropriately added with a clear comment and a sensible default value of True to maintain backward compatibility.

tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)

292-296: LGTM!

The new enable_cuda_graph_for_draft_model parameter is correctly propagated to EagleDecodingConfig, following the same pattern as the other configuration parameters.

tests/integration/defs/.test_durations (1)

312-313: Test duration entries may be incomplete for new test variants.

The test duration file only includes entries for enable_cuda_graph_for_draft_model=False variants. However, llm_digits_func.txt also adds enable_cuda_graph_for_draft_model=True-sampler_async_worker=True-eagle3_one_model=True-overlap_scheduler=True. Consider adding duration estimates for the new True variant to ensure accurate test scheduling.

Also note that the new test parameterization includes sampler_async_worker which is not reflected in these duration keys.

tensorrt_llm/_torch/models/modeling_speculative.py (1)

846-872: LGTM - clean extraction of draft forwarding logic.

The forward_draft method cleanly encapsulates the draft-path logic. The handling of padded tokens (lines 857-863) correctly slices inputs before passing to the spec worker.

Consider adding a brief docstring describing what this method does and its return type for maintainability.
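
For instance, a Google-style docstring along those lines might look like this (wording is illustrative; the signature mirrors the hunk shown in this PR):

def forward_draft(self, hidden_states, input_ids, position_ids, attn_metadata,
                  spec_metadata):
    """Run the draft path on the target model's hidden states.

    Computes target logits from the gathered hidden states, slices away any
    CUDA-graph padding tokens, and invokes the spec worker to produce accepted
    tokens and next draft tokens.

    Args:
        hidden_states: Hidden states produced by the target model forward.
        input_ids: Token ids for the current step (possibly padded).
        position_ids: Position ids for the current step (possibly padded).
        attn_metadata: Attention metadata (padded_num_tokens, num_tokens, ...).
        spec_metadata: Speculative-decoding metadata (gather_ids, ...).

    Returns:
        The spec worker output (accepted tokens, next draft tokens, logits, ...).
    """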

tests/integration/test_lists/qa/llm_digits_func.txt (1)

19-22: No action required. The test case at line 19 is properly handled.

The test_eagle3 function includes an explicit guard at lines 272-275 that skips the test when enable_cuda_graph_for_draft_model=False and eagle3_one_model=False. This combination at line 19 will be skipped by pytest, not fail during initialization. Including this parameter combination in the test list is a valid pattern for documenting all possible parameter configurations, even when some are intentionally skipped.

Likely an incorrect or invalid review comment.

tests/integration/test_lists/qa/llm_function_rtx6k.txt (1)

62-65: Eagle3 CUDA-graph draft-flag variants look well-structured

The four Llama3.1-8B Eagle3 variants cover both values of enable_cuda_graph_for_draft_model across relevant sampler_async_worker / eagle3_one_model / overlap_scheduler combinations; fits the rest of the matrix and naming conventions.

tests/integration/test_lists/qa/llm_function_core_sanity.txt (1)

131-134: Sanity list Eagle3 coverage matches full-function matrix

Mirrors the four Eagle3 combinations used elsewhere, so core-sanity will exercise the same enable_cuda_graph_for_draft_model / scheduler / worker matrix in a lighter run. No issues noted.

tests/integration/test_lists/qa/llm_function_l20.txt (1)

24-27: L20 single-GPU Eagle3 variants are consistent with other QA lists

Single-GPU list reuses the same four Eagle3 parameter combinations, so behavior is aligned across RTX6k/core-sanity/L20 contexts. Looks good.

tests/integration/test_lists/test-db/l0_h100.yml (1)

61-64: H100 Eagle3 test matrix appropriately exercises the new draft CUDA-graph flag

The four H100 Eagle3 entries span both values of enable_cuda_graph_for_draft_model with different sampler_async_worker / eagle3_one_model combinations, which should give good coverage of the new path on H100.

assert key in model_config.extra_attrs
model_config.extra_attrs[key].update(value)
self.layer_idx = -1
self.enable_cuda_graph_for_draft_model = spec_config.enable_cuda_graph_for_draft_model

⚠️ Potential issue | 🔴 Critical

Potential AttributeError when spec_config is None.

The spec_config variable can be None (assigned via getattr(model_config, 'spec_config', None) on line 741). Accessing spec_config.enable_cuda_graph_for_draft_model directly without a null check will raise an AttributeError.

🔎 Proposed fix
         self.layer_idx = -1
-        self.enable_cuda_graph_for_draft_model = spec_config.enable_cuda_graph_for_draft_model
+        self.enable_cuda_graph_for_draft_model = spec_config.enable_cuda_graph_for_draft_model if spec_config else True

Note: Defaulting to True preserves backward-compatible behavior (CUDA graph capture enabled by default).

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        self.enable_cuda_graph_for_draft_model = spec_config.enable_cuda_graph_for_draft_model
+        self.layer_idx = -1
+        self.enable_cuda_graph_for_draft_model = spec_config.enable_cuda_graph_for_draft_model if spec_config else True
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_speculative.py around line 799
(spec_config originates from getattr(..., 'spec_config', None) on line 741),
accessing spec_config.enable_cuda_graph_for_draft_model can raise AttributeError
when spec_config is None; change the assignment to read the attribute safely,
e.g. set self.enable_cuda_graph_for_draft_model = getattr(spec_config,
'enable_cuda_graph_for_draft_model', True) or check if spec_config is not None
before accessing and default to True to preserve backward-compatible behavior.

Comment on lines +827 to +835
+        is_capturing = torch.cuda.is_current_stream_capturing()
+
         if self.draft_model is not None:
-            # get logits
-            logits = self.logits_processor.forward(
-                hidden_states[spec_metadata.gather_ids],
-                self.lm_head,
-                attn_metadata,
-                True,
-            )
-            mtp_input_ids = input_ids
-            mtp_position_ids = position_ids
-            if attn_metadata.padded_num_tokens is not None:
-                if input_ids is not None:
-                    # Slice along the first dimension
-                    mtp_input_ids = input_ids[:attn_metadata.num_tokens]
-                if position_ids is not None:
-                    # Slice along the last dimension
-                    mtp_position_ids = position_ids[:, :attn_metadata.num_tokens]
-
-            # get accepted tokens and next draft tokens
-            return self.spec_worker(input_ids=mtp_input_ids,
-                                    position_ids=mtp_position_ids,
-                                    hidden_states=hidden_states,
-                                    logits=logits,
-                                    attn_metadata=attn_metadata,
-                                    spec_metadata=spec_metadata,
-                                    draft_model=self.draft_model)
+            if is_capturing and not self.enable_cuda_graph_for_draft_model:
+                return hidden_states
+            else:
+                return self.forward_draft(hidden_states, input_ids,
+                                          position_ids, attn_metadata,
+                                          spec_metadata)

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, find the method that contains lines 827-835
head -835 tensorrt_llm/_torch/models/modeling_speculative.py | tail -50

Repository: NVIDIA/TensorRT-LLM

Length of output: 2350


🏁 Script executed:

#!/bin/bash
# Also check for forward_draft method and its return type
rg -n "def forward_draft" tensorrt_llm/_torch/models/modeling_speculative.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 134


🏁 Script executed:

#!/bin/bash
# Find where this code is called in model_engine.py
rg -n "self.draft_model" tensorrt_llm/_torch/pyexecutor/model_engine.py -B 5 -A 15

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Get the forward_draft method signature and implementation
sed -n '846,900p' tensorrt_llm/_torch/models/modeling_speculative.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 2456


🏁 Script executed:

#!/bin/bash
# Find where SpeculativeHeadWrapper.forward is called
rg -n "\.forward(" tensorrt_llm/_torch/pyexecutor/model_engine.py -B 2 -A 5

Repository: NVIDIA/TensorRT-LLM

Length of output: 133


🏁 Script executed:

#!/bin/bash
# Search for usage of the return value from the model's forward/forward_draft
rg -n "forward_draft\|SpeculativeHeadWrapper" tensorrt_llm/_torch/pyexecutor/

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Fix the regex and search for forward calls
rg -n "\.forward\(" tensorrt_llm/_torch/pyexecutor/model_engine.py -B 2 -A 5

Repository: NVIDIA/TensorRT-LLM

Length of output: 2751


🏁 Script executed:

#!/bin/bash
# Search for SpeculativeHeadWrapper usage
rg -n "SpeculativeHeadWrapper" tensorrt_llm/_torch/

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Look at the full forward method to understand return type
sed -n '800,850p' tensorrt_llm/_torch/models/modeling_speculative.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1982


🏁 Script executed:

#!/bin/bash
# Search more broadly for how forward is used and what happens to its return value
rg -n "model.forward\|\.forward\(\)" tensorrt_llm/_torch/pyexecutor/ -A 3 | head -100

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Check what spec_worker returns
rg -n "class.*SpecWorker\|def __call__" tensorrt_llm/_torch/ | grep -i spec

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Look for spec_worker definition and return type
fd "spec_worker" tensorrt_llm/_torch/ --type f

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Let's check what self.spec_worker is - search for its definition/assignment
rg -n "self\.spec_worker\s*=" tensorrt_llm/_torch/models/modeling_speculative.py -B 2 -A 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 438


🏁 Script executed:

#!/bin/bash
# Search for get_spec_worker function
rg -n "def get_spec_worker\|from.*spec_worker" tensorrt_llm/_torch/

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Look at the broader file structure to understand the class
head -50 tensorrt_llm/_torch/models/modeling_speculative.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 2088


🏁 Script executed:

#!/bin/bash
# Find get_spec_worker definition
rg -n "def get_spec_worker" tensorrt_llm/ -A 30

Repository: NVIDIA/TensorRT-LLM

Length of output: 2713


🏁 Script executed:

#!/bin/bash
# Look for speculative module files
fd "speculative" tensorrt_llm/_torch/ --type f

Repository: NVIDIA/TensorRT-LLM

Length of output: 113


🏁 Script executed:

#!/bin/bash
# Check the return type annotation in forward method more carefully
sed -n '814,846p' tensorrt_llm/_torch/models/modeling_speculative.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1323


🏁 Script executed:

#!/bin/bash
# Check what the worker classes return by looking at their __call__ methods
rg -n "class MTPWorker\|class MTPEagleWorker\|class Eagle3OneModelWorker" tensorrt_llm/_torch/speculative/ -A 50 | head -150

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Search for __call__ method in speculative worker classes
rg -n "def __call__" tensorrt_llm/_torch/speculative/ -A 15

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Look at full forward method and its return type annotation
sed -n '807,851p' tensorrt_llm/_torch/models/modeling_speculative.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1774


🏁 Script executed:

#!/bin/bash
# Check class definition and where forward is called
rg -n "class SpeculativeHeadWrapper" tensorrt_llm/_torch/models/modeling_speculative.py -A 5

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Look at how the model output is used after calling forward
rg -n "SpeculativeHeadWrapper\|speculative_head_wrapper" tensorrt_llm/_torch/pyexecutor/model_engine.py -B 5 -A 10

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Check what modules/layers use SpeculativeHeadWrapper
rg -n "SpeculativeHeadWrapper" tensorrt_llm/_torch/ -B 2 -A 2

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Search for the class name from the file context
rg -n "class.*Wrapper" tensorrt_llm/_torch/models/modeling_speculative.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Look at imports and overall file structure
head -100 tensorrt_llm/_torch/models/modeling_speculative.py | tail -50

Repository: NVIDIA/TensorRT-LLM

Length of output: 2201


🏁 Script executed:

#!/bin/bash
# Search for where speculative models are instantiated
rg -n "modeling_speculative\|from.*modeling_speculative" tensorrt_llm/_torch/ -l

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Find the class that contains this forward method - search backwards from line 807
sed -n '700,810p' tensorrt_llm/_torch/models/modeling_speculative.py | grep "^class"

Repository: NVIDIA/TensorRT-LLM

Length of output: 138


🏁 Script executed:

#!/bin/bash
# Let's get a broader view of the file structure
rg -n "^class " tensorrt_llm/_torch/models/modeling_speculative.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 611


🏁 Script executed:

#!/bin/bash
# Look at the entire context around the forward method to find the class
sed -n '750,815p' tensorrt_llm/_torch/models/modeling_speculative.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 3735


🏁 Script executed:

#!/bin/bash
# Now search for where SpecDecOneEngineForCausalLM is used and how the forward result is handled
rg -n "SpecDecOneEngineForCausalLM" tensorrt_llm/_torch/ -B 2 -A 5

Repository: NVIDIA/TensorRT-LLM

Length of output: 14296


🏁 Script executed:

#!/bin/bash
# Look for the model_engine forward method to see how it uses the model output
rg -n "def forward\(" tensorrt_llm/_torch/pyexecutor/model_engine.py | head -20

Repository: NVIDIA/TensorRT-LLM

Length of output: 114


🏁 Script executed:

#!/bin/bash
# Check line around 3297-3300 that we saw earlier which calls model.forward
sed -n '3290,3310p' tensorrt_llm/_torch/pyexecutor/model_engine.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1102


🏁 Script executed:

#!/bin/bash
# Check what spec_worker types return
rg -n "class.*Worker.*:" tensorrt_llm/_torch/speculative/ -A 30 | grep -A 25 "class MTPWorker\|class Eagle3OneModelWorker"

Repository: NVIDIA/TensorRT-LLM

Length of output: 4316


🏁 Script executed:

#!/bin/bash
# Look for the __call__ method in worker classes
fd "speculative" tensorrt_llm/_torch/ --type f -name "*.py"

Repository: NVIDIA/TensorRT-LLM

Length of output: 293


🏁 Script executed:

#!/bin/bash
# Check worker implementations
head -300 tensorrt_llm/_torch/speculative/worker.py | tail -200

Repository: NVIDIA/TensorRT-LLM

Length of output: 163


🏁 Script executed:

#!/bin/bash
# Find the return statement in MTPWorker.forward
rg -n "return" tensorrt_llm/_torch/speculative/mtp.py | grep -A 2 "MTPWorker\|def forward" | head -30

Repository: NVIDIA/TensorRT-LLM

Length of output: 45


🏁 Script executed:

#!/bin/bash
# Get the full MTPWorker.forward method
sed -n '362,550p' tensorrt_llm/_torch/speculative/mtp.py | tail -100

Repository: NVIDIA/TensorRT-LLM

Length of output: 4632


🏁 Script executed:

#!/bin/bash
# Get the full Eagle3OneModelWorker.forward method
sed -n '373,600p' tensorrt_llm/_torch/speculative/eagle3.py | tail -150

Repository: NVIDIA/TensorRT-LLM

Length of output: 6246


🏁 Script executed:

#!/bin/bash
# Check if there's any type casting or handling of the return value
sed -n '3295,3320p' tensorrt_llm/_torch/pyexecutor/model_engine.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 992


🏁 Script executed:

#!/bin/bash
# Look at MTPWorker.forward return statement more carefully
rg -n "return {" tensorrt_llm/_torch/speculative/mtp.py | head -5

Repository: NVIDIA/TensorRT-LLM

Length of output: 142


🏁 Script executed:

#!/bin/bash
# Get exact return statement from MTPWorker
sed -n '500,530p' tensorrt_llm/_torch/speculative/mtp.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1629


🏁 Script executed:

#!/bin/bash
# Find the actual return statement in MTPWorker.forward
rg -n "return \{" tensorrt_llm/_torch/speculative/mtp.py -A 10

Repository: NVIDIA/TensorRT-LLM

Length of output: 1963


🏁 Script executed:

#!/bin/bash
# Look at the end of MTPWorker.forward method
tail -200 tensorrt_llm/_torch/speculative/mtp.py | head -100

Repository: NVIDIA/TensorRT-LLM

Length of output: 5501


🏁 Script executed:

#!/bin/bash
# Check what model_forward returns in model_engine
sed -n '3310,3350p' tensorrt_llm/_torch/pyexecutor/model_engine.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 1816


🏁 Script executed:

#!/bin/bash
# Get the exact method signature
sed -n '807,820p' tensorrt_llm/_torch/models/modeling_speculative.py

Repository: NVIDIA/TensorRT-LLM

Length of output: 539


Fix return type annotation.

The method's return type is annotated as torch.Tensor, but it can actually return either a torch.Tensor (when is_capturing and not self.enable_cuda_graph_for_draft_model) or a dict (from self.spec_worker). The spec_worker returns a dict with keys like logits, new_tokens, new_tokens_lens, etc., not a tensor. While PyTorchModelEngine._forward_step correctly handles both types using isinstance(outputs, dict), the method's return type annotation should be updated to reflect the actual return types (e.g., Union[torch.Tensor, Dict[str, Any]]).

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_speculative.py around lines 827 to 835,
the method is annotated to return torch.Tensor but actually may return either a
torch.Tensor or a dict from the spec_worker; update the return type annotation
to Union[torch.Tensor, Dict[str, Any]] (or typing.Any as appropriate),
add/import Union, Dict, Any from typing at the top of the file if not present,
and adjust any related type comments or stubs so the signature matches both
possible return types.
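
A sketch of the adjusted signature (parameter list abbreviated; everything except the Union return annotation is a placeholder):

from typing import Any, Dict, Union

import torch

# Signature sketch for SpecDecOneEngineForCausalLM.forward (parameters abbreviated):
def forward(self, input_ids, position_ids, attn_metadata, spec_metadata,
            **kwargs) -> Union[torch.Tensor, Dict[str, Any]]:
    """Return hidden_states on the capture-skip path, otherwise the spec worker's output dict."""
    ...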

) or self.model_is_wrapped
self.max_draft_len = spec_config.max_draft_len
self.max_total_draft_tokens = spec_config.max_total_draft_tokens
self.enable_cuda_graph_for_draft_model = spec_config.enable_cuda_graph_for_draft_model

⚠️ Potential issue | 🔴 Critical

Guard enable_cuda_graph_for_draft_model and forward_draft usage for non-speculative executors

Right now enable_cuda_graph_for_draft_model is only assigned when spec_config is not None, but the CUDA-graph replay path reads it unconditionally and also assumes inputs['spec_metadata'] exists. For non-speculative executors (no spec_config) that still use CUDA graphs, this can lead to:

  • AttributeError on self.enable_cuda_graph_for_draft_model.
  • Or, once the attribute is initialized, an invalid attempt to call model.forward_draft without spec_metadata and on models that don’t implement that method.

You likely only intend to run forward_draft in speculative mode when the new flag is False. Suggest initializing the flag for all cases and gating the forward_draft call on self.enable_spec_decode as well.

Proposed fix: initialize flag safely and gate the `forward_draft` call
@@
-        self.llm_args = llm_args
-        self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
-        self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
+        self.llm_args = llm_args
+        self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
+        self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
@@
-        self.spec_config = spec_config
-        self.is_spec_decode = spec_config is not None
+        self.spec_config = spec_config
+        self.is_spec_decode = spec_config is not None
+        # Default to True so non-speculative executors never take the draft-only path.
+        self.enable_cuda_graph_for_draft_model = (
+            spec_config.enable_cuda_graph_for_draft_model
+            if spec_config is not None else True
+        )
         self.sparse_attention_config = None if is_draft_model else llm_args.sparse_attention_config
         self.enable_spec_decode = self.is_spec_decode
         self.is_draft_model = is_draft_model
@@
-                    else:
-                        with MoeLoadBalancerIterContext(moe_load_balancer):
-                            outputs = self.cuda_graph_runner.replay(key, inputs)
-                            if not self.enable_cuda_graph_for_draft_model:
-                                outputs = self.model.forward_draft(
-                                    outputs, inputs['input_ids'],
-                                    inputs['position_ids'],
-                                    inputs['attn_metadata'],
-                                    inputs['spec_metadata'])
+                    else:
+                        with MoeLoadBalancerIterContext(moe_load_balancer):
+                            outputs = self.cuda_graph_runner.replay(key, inputs)
+                            # When speculative decoding is enabled but we opted out of
+                            # capturing the draft loop in the CUDA graph, run the
+                            # draft-only pass after replay.
+                            if (self.enable_spec_decode
+                                    and not self.enable_cuda_graph_for_draft_model):
+                                outputs = self.model.forward_draft(
+                                    outputs,
+                                    inputs['input_ids'],
+                                    inputs['position_ids'],
+                                    inputs['attn_metadata'],
+                                    inputs['spec_metadata'],
+                                )

This keeps non-speculative flows and non-Eagle3 models on the existing path while enabling the new “draft outside CUDA graph” behavior only where spec_config and forward_draft are defined.

Also applies to: 3269-3274

@tensorrt-cicd
Collaborator

PR_Github #29690 [ run ] triggered by Bot. Commit: 9ca412a

@tensorrt-cicd
Collaborator

PR_Github #29690 [ run ] completed with state SUCCESS. Commit: 9ca412a
/LLM/main/L0_MergeRequest_PR pipeline #22805 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again
