[None][feat] Not CUDA graph captured eagle3 one-model draft loop #10251
base: main
Conversation
Signed-off-by: Jhao-Ting Chen <[email protected]>
/bot run --disable-fail-fast
📝 Walkthrough

This pull request introduces a new configuration flag, `enable_cuda_graph_for_draft_model`, which controls whether the Eagle3 one-model draft loop is captured inside the CUDA graph.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client as Client/Main Forward
    participant Model as Speculative Model
    participant Draft as Draft Model
    participant Executor as PyTorch Executor
    rect rgb(200, 220, 255)
        Note over Client,Executor: Standard Path: enable_cuda_graph_for_draft_model = True
        Client->>Executor: Forward with CUDA graph
        Executor->>Model: forward() with graph capture active
        Model->>Draft: forward_draft() [captured in graph]
        Draft-->>Model: draft logits
        Model-->>Executor: spec results
        Executor-->>Client: output
    end
    rect rgb(255, 220, 200)
        Note over Client,Executor: Alternate Path: enable_cuda_graph_for_draft_model = False
        Client->>Executor: Forward with CUDA graph
        Executor->>Model: forward() with graph capture active
        Model-->>Executor: hidden_states [early return, skip draft]
        Executor->>Model: forward_draft() [post-graph-replay]
        Model->>Draft: Compute draft logits with original inputs
        Draft-->>Model: draft logits
        Model-->>Executor: refined outputs
        Executor-->>Client: output
    end
```
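As a rough illustration of the alternate path above, here is a hedged sketch (not the PR's actual implementation; the attribute names and input keys follow the snippets quoted later in this review):

```python
# Sketch of the "draft loop outside the CUDA graph" flow, assuming the engine
# exposes a graph runner whose replay() returns only hidden_states when the
# draft loop is skipped during capture, and a model.forward_draft() that runs
# the draft loop eagerly afterwards.
def run_generation_step(engine, key, inputs):
    # Replay the captured graph; with the flag off, the graph covers only the
    # target-model forward and returns hidden_states.
    outputs = engine.cuda_graph_runner.replay(key, inputs)
    if not engine.enable_cuda_graph_for_draft_model:
        # The draft loop executes outside the captured graph, so additional
        # sampling or CPU-side work can be inserted around it.
        outputs = engine.model.forward_draft(
            outputs,                      # hidden_states from the replayed graph
            inputs["input_ids"],
            inputs["position_ids"],
            inputs["attn_metadata"],
            inputs["spec_metadata"],
        )
    return outputs
```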
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~22 minutes

Possibly related PRs
Suggested labels
Suggested reviewers
Pre-merge checks and finishing touches
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 3
🧹 Nitpick comments (2)
tensorrt_llm/llmapi/llm_args.py (1)
917-921: Use `is False` for explicit boolean checks with `Optional[bool]`.

The static analysis flags the `== False` comparisons. Since `enable_cuda_graph_for_draft_model` is `Optional[bool]`, using `is False` is more explicit and avoids potential issues with truthiness comparisons:

🔎 Suggested fix

```diff
-        if self.enable_cuda_graph_for_draft_model == False and self.eagle3_one_model == False:
+        if self.enable_cuda_graph_for_draft_model is False and self.eagle3_one_model is False:
             raise ValueError(
                 "enable_cuda_graph_for_draft_model can be false only when eagle3_one_model is True"
             )
```

tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)
269-275: LGTM with optional style improvement.

The new test parameter correctly covers both enabling and disabling CUDA graph capture for the draft model. The skip condition appropriately prevents invalid test combinations where the flag is disabled in two-model mode.

Optional: Address linter suggestion for more Pythonic comparison

```diff
-    if enable_cuda_graph_for_draft_model == False and eagle3_one_model == False:
+    if not enable_cuda_graph_for_draft_model and not eagle3_one_model:
         pytest.skip(
             "enable_cuda_graph_for_draft_model can be false only when eagle3_one_model is True"
         )
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- tensorrt_llm/_torch/models/modeling_speculative.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/llmapi/llm_args.py
- tests/integration/defs/.test_durations
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
- tests/integration/test_lists/qa/llm_digits_func.txt
- tests/integration/test_lists/qa/llm_function_core.txt
- tests/integration/test_lists/qa/llm_function_core_sanity.txt
- tests/integration/test_lists/qa/llm_function_l20.txt
- tests/integration/test_lists/qa/llm_function_rtx6k.txt
- tests/integration/test_lists/test-db/l0_h100.yml
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces. Do not use tabs
Always maintain the namespace when importing in Python, even if only one class or function from a module is used
Python files should use snake_case naming: `some_file.py`
Python classes should use PascalCase naming: `class SomeClass`
Python functions and methods should use snake_case naming: `def my_awesome_function():`
Python local variables should use snake_case naming: `my_variable = ...`
Python variable names that start with a number should be prefixed with 'k': `k_99th_percentile = ...`
Python global variables should use upper snake_case with prefix 'G': `G_MY_GLOBAL = ...`
Python constants should use upper snake_case naming: `MY_CONSTANT = ...`
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Python comments should be reserved for code within a function, or interfaces that are local to a file
Use Google style docstrings in Python for classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with type and description
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except to the smallest set of errors possible
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible, using the else block for logic
Files:
- tensorrt_llm/llmapi/llm_args.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
- tensorrt_llm/_torch/models/modeling_speculative.py
**/*.{cpp,h,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification
Files:
- tensorrt_llm/llmapi/llm_args.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
- tensorrt_llm/_torch/models/modeling_speculative.py
🧠 Learnings (12)
📓 Common learnings
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
Learnt from: nzmora-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 9163
File: tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py:107-113
Timestamp: 2025-11-14T11:22:03.729Z
Learning: In TensorRT-LLM AutoDeploy custom ops, when adding hardware capability checks to select between kernel implementations (e.g., cuBLAS vs. CUDA kernel), use descriptive variable names that identify the specific GPU architectures or families being targeted (e.g., `is_blackwell_geforce_or_ada`) rather than generic names like `enable_cuda_core`. This makes it clear that the code is selecting an implementation path based on hardware capabilities, not enabling/disabling hardware features.
📚 Learning: 2025-09-09T09:40:45.658Z
Learnt from: fredricz-20070104
Repo: NVIDIA/TensorRT-LLM PR: 7645
File: tests/integration/test_lists/qa/llm_function_core.txt:648-648
Timestamp: 2025-09-09T09:40:45.658Z
Learning: In TensorRT-LLM test lists, it's common and intentional for the same test to appear in multiple test list files when they serve different purposes (e.g., llm_function_core.txt for comprehensive core functionality testing and llm_function_core_sanity.txt for quick sanity checks). This duplication allows tests to be run in different testing contexts.
Applied to files:
- tests/integration/test_lists/qa/llm_function_core_sanity.txt
- tests/integration/defs/.test_durations
- tests/integration/test_lists/test-db/l0_h100.yml
- tests/integration/test_lists/qa/llm_function_core.txt
- tests/integration/test_lists/qa/llm_function_l20.txt
- tests/integration/test_lists/qa/llm_function_rtx6k.txt
- tests/integration/test_lists/qa/llm_digits_func.txt
📚 Learning: 2025-08-26T09:49:04.956Z
Learnt from: pengbowang-nv
Repo: NVIDIA/TensorRT-LLM PR: 7192
File: tests/integration/test_lists/test-db/l0_dgx_b200.yml:56-72
Timestamp: 2025-08-26T09:49:04.956Z
Learning: In TensorRT-LLM test configuration files, the test scheduling system handles wildcard matching with special rules that prevent duplicate test execution even when the same tests appear in multiple yaml files with overlapping GPU wildcards (e.g., "*b200*" and "*gb200*").
Applied to files:
- tests/integration/test_lists/qa/llm_function_core_sanity.txt
- tests/integration/test_lists/test-db/l0_h100.yml
- tests/integration/test_lists/qa/llm_function_core.txt
- tests/integration/test_lists/qa/llm_function_l20.txt
- tests/integration/test_lists/qa/llm_function_rtx6k.txt
- tests/integration/test_lists/qa/llm_digits_func.txt
📚 Learning: 2025-09-17T02:48:52.732Z
Learnt from: tongyuantongyu
Repo: NVIDIA/TensorRT-LLM PR: 7781
File: tests/integration/test_lists/waives.txt:313-313
Timestamp: 2025-09-17T02:48:52.732Z
Learning: In TensorRT-LLM, `tests/integration/test_lists/waives.txt` is specifically for waiving/skipping tests, while other test list files like those in `test-db/` and `qa/` directories are for different test execution contexts (pre-merge, post-merge, QA tests). The same test appearing in both waives.txt and execution list files is intentional - the test is part of test suites but will be skipped due to the waiver.
Applied to files:
- tests/integration/test_lists/qa/llm_function_core_sanity.txt
- tests/integration/test_lists/test-db/l0_h100.yml
- tests/integration/test_lists/qa/llm_function_core.txt
- tests/integration/test_lists/qa/llm_function_rtx6k.txt
📚 Learning: 2025-11-27T09:23:18.742Z
Learnt from: fredricz-20070104
Repo: NVIDIA/TensorRT-LLM PR: 9511
File: tests/integration/defs/examples/serve/test_serve.py:136-186
Timestamp: 2025-11-27T09:23:18.742Z
Learning: In TensorRT-LLM testing, when adding test cases based on RCCA commands, the command format should be copied exactly as it appears in the RCCA case, even if it differs from existing tests. For example, some RCCA commands for trtllm-serve may omit the "serve" subcommand while others include it.
Applied to files:
- tests/integration/test_lists/qa/llm_function_core_sanity.txt
- tests/integration/test_lists/qa/llm_function_rtx6k.txt
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Applied to files:
- tests/integration/test_lists/qa/llm_function_core_sanity.txt
- tests/integration/defs/.test_durations
- tests/integration/test_lists/test-db/l0_h100.yml
- tests/integration/test_lists/qa/llm_function_core.txt
- tests/integration/test_lists/qa/llm_function_l20.txt
- tests/integration/test_lists/qa/llm_function_rtx6k.txt
- tests/integration/test_lists/qa/llm_digits_func.txt
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
Repo: NVIDIA/TensorRT-LLM PR: 6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
- tests/integration/test_lists/test-db/l0_h100.yml
- tests/integration/test_lists/qa/llm_function_rtx6k.txt
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which can contain default `cuda_graph_config` values, so `llm_args` may already have this config before the extra options processing.
Applied to files:
- tensorrt_llm/llmapi/llm_args.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM's bench configuration, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which is a Dict[str, Any] that can contain default values including `cuda_graph_config`, making the fallback `llm_args["cuda_graph_config"]` safe to use.
Applied to files:
- tensorrt_llm/llmapi/llm_args.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tests/integration/defs/accuracy/test_llm_api_pytorch.py
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.
Applied to files:
tensorrt_llm/_torch/pyexecutor/model_engine.py
📚 Learning: 2025-11-14T11:22:03.729Z
Learnt from: nzmora-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 9163
File: tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py:107-113
Timestamp: 2025-11-14T11:22:03.729Z
Learning: In TensorRT-LLM AutoDeploy custom ops, when adding hardware capability checks to select between kernel implementations (e.g., cuBLAS vs. CUDA kernel), use descriptive variable names that identify the specific GPU architectures or families being targeted (e.g., `is_blackwell_geforce_or_ada`) rather than generic names like `enable_cuda_core`. This makes it clear that the code is selecting an implementation path based on hardware capabilities, not enabling/disabling hardware features.
Applied to files:
tensorrt_llm/_torch/pyexecutor/model_engine.py
📚 Learning: 2025-12-12T03:27:08.565Z
Learnt from: tongyuantongyu
Repo: NVIDIA/TensorRT-LLM PR: 9655
File: tensorrt_llm/_torch/pyexecutor/sampler.py:3031-3031
Timestamp: 2025-12-12T03:27:08.565Z
Learning: In files under tensorrt_llm/_torch/pyexecutor, avoid accessing torch.Tensor objects inside for-loops when iterating over requests. Convert batched tensors to Python lists beforehand using tensor.tolist(), and then iterate over those lists. This improves performance by reducing tensor-bound operations inside hot loops. Apply this pattern to similar code paths that process batches to access simple Python data structures (lists) inside loops.
Applied to files:
tensorrt_llm/_torch/pyexecutor/model_engine.py
🧬 Code graph analysis (2)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
- tensorrt_llm/_torch/models/modeling_speculative.py (1)
  - forward_draft (846-872)

tensorrt_llm/_torch/models/modeling_speculative.py (3)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
  - forward (77-84)
  - forward (3150-3281)
- tensorrt_llm/_torch/speculative/eagle3.py (1)
  - forward (373-504)
- tensorrt_llm/_torch/speculative/mtp.py (2)
  - forward (362-559)
  - forward (1200-1388)
🪛 Ruff (0.14.10)
tensorrt_llm/llmapi/llm_args.py
917-917: Avoid equality comparisons to False; use not self.enable_cuda_graph_for_draft_model: for false checks
Replace with not self.enable_cuda_graph_for_draft_model
(E712)
917-917: Avoid equality comparisons to False; use not self.eagle3_one_model: for false checks
Replace with not self.eagle3_one_model
(E712)
918-920: Avoid specifying long messages outside the exception class
(TRY003)
tests/integration/defs/accuracy/test_llm_api_pytorch.py
272-272: Avoid equality comparisons to False; use not enable_cuda_graph_for_draft_model: for false checks
Replace with not enable_cuda_graph_for_draft_model
(E712)
272-272: Avoid equality comparisons to False; use not eagle3_one_model: for false checks
Replace with not eagle3_one_model
(E712)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (10)
tests/integration/test_lists/qa/llm_function_core.txt (1)
393-393: Test configuration may conflict with validation logic.

Line 393 specifies `enable_cuda_graph_for_draft_model=False` with `eagle3_one_model=False`. However, the new validation in `EagleDecodingConfig.__init__` (llm_args.py lines 917-920) explicitly raises a `ValueError` when this combination is used:

"enable_cuda_graph_for_draft_model can be false only when eagle3_one_model is True"
Please verify whether this test is intentionally testing the error path (negative test) or if the parameters need to be corrected.
tensorrt_llm/llmapi/llm_args.py (1)
861-862: LGTM!

The new `enable_cuda_graph_for_draft_model` field is appropriately added with a clear comment and a sensible default value of `True` to maintain backward compatibility.

tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)
292-296: LGTM!

The new `enable_cuda_graph_for_draft_model` parameter is correctly propagated to `EagleDecodingConfig`, following the same pattern as the other configuration parameters.

tests/integration/defs/.test_durations (1)
312-313: Test duration entries may be incomplete for new test variants.

The test duration file only includes entries for `enable_cuda_graph_for_draft_model=False` variants. However, llm_digits_func.txt also adds `enable_cuda_graph_for_draft_model=True-sampler_async_worker=True-eagle3_one_model=True-overlap_scheduler=True`. Consider adding duration estimates for the new `True` variant to ensure accurate test scheduling.

Also note that the new test parameterization includes `sampler_async_worker`, which is not reflected in these duration keys.

tensorrt_llm/_torch/models/modeling_speculative.py (1)
846-872: LGTM - clean extraction of draft forwarding logic.

The `forward_draft` method cleanly encapsulates the draft-path logic. The handling of padded tokens (lines 857-863) correctly slices inputs before passing to the spec worker.

Consider adding a brief docstring describing what this method does and its return type for maintainability.
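A minimal sketch of such a docstring, assuming the signature quoted elsewhere in this review and following the Google style required by the coding guidelines (the parameter descriptions are illustrative, not taken from the PR):

```python
def forward_draft(self, hidden_states, input_ids, position_ids, attn_metadata,
                  spec_metadata):
    """Run the draft loop on top of the target model's hidden states.

    Slices CUDA-graph padding away from ``input_ids``/``position_ids`` when
    ``attn_metadata.padded_num_tokens`` is set, then invokes the spec worker.

    Args:
        hidden_states: Target-model hidden states from the main forward pass.
        input_ids: Possibly padded input token IDs.
        position_ids: Possibly padded position IDs.
        attn_metadata: Attention metadata for the current batch.
        spec_metadata: Speculative-decoding metadata (gather IDs, etc.).

    Returns:
        The spec worker output (accepted tokens and next draft tokens).
    """
    ...
```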
tests/integration/test_lists/qa/llm_digits_func.txt (1)
19-22: No action required. The test case at line 19 is properly handled.

The `test_eagle3` function includes an explicit guard at lines 272-275 that skips the test when `enable_cuda_graph_for_draft_model=False` and `eagle3_one_model=False`. This combination at line 19 will be skipped by pytest, not fail during initialization. Including this parameter combination in the test list is a valid pattern for documenting all possible parameter configurations, even when some are intentionally skipped.

Likely an incorrect or invalid review comment.
tests/integration/test_lists/qa/llm_function_rtx6k.txt (1)
62-65: Eagle3 CUDA-graph draft-flag variants look well-structured

The four Llama3.1-8B Eagle3 variants cover both values of `enable_cuda_graph_for_draft_model` across relevant `sampler_async_worker` / `eagle3_one_model` / `overlap_scheduler` combinations; fits the rest of the matrix and naming conventions.

tests/integration/test_lists/qa/llm_function_core_sanity.txt (1)
131-134: Sanity list Eagle3 coverage matches full-function matrix

Mirrors the four Eagle3 combinations used elsewhere, so core-sanity will exercise the same `enable_cuda_graph_for_draft_model` / scheduler / worker matrix in a lighter run. No issues noted.

tests/integration/test_lists/qa/llm_function_l20.txt (1)
24-27: L20 single-GPU Eagle3 variants are consistent with other QA lists

Single-GPU list reuses the same four Eagle3 parameter combinations, so behavior is aligned across RTX6k/core-sanity/L20 contexts. Looks good.
tests/integration/test_lists/test-db/l0_h100.yml (1)
61-64: H100 Eagle3 test matrix appropriately exercises the new draft CUDA-graph flag

The four H100 Eagle3 entries span both values of `enable_cuda_graph_for_draft_model` with different `sampler_async_worker` / `eagle3_one_model` combinations, which should give good coverage of the new path on H100.
```python
            assert key in model_config.extra_attrs
            model_config.extra_attrs[key].update(value)
        self.layer_idx = -1
        self.enable_cuda_graph_for_draft_model = spec_config.enable_cuda_graph_for_draft_model
```
Potential AttributeError when spec_config is None.
The spec_config variable can be None (assigned via getattr(model_config, 'spec_config', None) on line 741). Accessing spec_config.enable_cuda_graph_for_draft_model directly without a null check will raise an AttributeError.
🔎 Proposed fix
```diff
         self.layer_idx = -1
-        self.enable_cuda_graph_for_draft_model = spec_config.enable_cuda_graph_for_draft_model
+        self.enable_cuda_graph_for_draft_model = spec_config.enable_cuda_graph_for_draft_model if spec_config else True
```

Note: Defaulting to True preserves backward-compatible behavior (CUDA graph capture enabled by default).
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
        self.layer_idx = -1
        self.enable_cuda_graph_for_draft_model = spec_config.enable_cuda_graph_for_draft_model if spec_config else True
```
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_speculative.py around line 799
(spec_config originates from getattr(..., 'spec_config', None) on line 741),
accessing spec_config.enable_cuda_graph_for_draft_model can raise AttributeError
when spec_config is None; change the assignment to read the attribute safely,
e.g. set self.enable_cuda_graph_for_draft_model = getattr(spec_config,
'enable_cuda_graph_for_draft_model', True) or check if spec_config is not None
before accessing and default to True to preserve backward-compatible behavior.
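For reference, a `getattr`-based variant of that fix, sketched under the assumption that defaulting to `True` is the desired fallback for configs without the new field:

```python
# Read the flag defensively: spec_config may be None for non-speculative
# models, and older spec configs may predate the new field.
self.enable_cuda_graph_for_draft_model = getattr(
    spec_config, "enable_cuda_graph_for_draft_model", True)
```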
```python
        is_capturing = torch.cuda.is_current_stream_capturing()

        if self.draft_model is not None:
            # get logits
            logits = self.logits_processor.forward(
                hidden_states[spec_metadata.gather_ids],
                self.lm_head,
                attn_metadata,
                True,
            )
            mtp_input_ids = input_ids
            mtp_position_ids = position_ids
            if attn_metadata.padded_num_tokens is not None:
                if input_ids is not None:
                    # Slice along the first dimension
                    mtp_input_ids = input_ids[:attn_metadata.num_tokens]
                if position_ids is not None:
                    # Slice along the last dimension
                    mtp_position_ids = position_ids[:, :attn_metadata.
                                                    num_tokens]

            # get accepted tokens and next draft tokens
            return self.spec_worker(input_ids=mtp_input_ids,
                                    position_ids=mtp_position_ids,
                                    hidden_states=hidden_states,
                                    logits=logits,
                                    attn_metadata=attn_metadata,
                                    spec_metadata=spec_metadata,
                                    draft_model=self.draft_model)
            if is_capturing and not self.enable_cuda_graph_for_draft_model:
                return hidden_states
            else:
                return self.forward_draft(hidden_states, input_ids,
                                          position_ids, attn_metadata,
                                          spec_metadata)
```
🧩 Analysis chain
🏁 Scripts executed (repository: NVIDIA/TensorRT-LLM):

```bash
#!/bin/bash
# First, find the method that contains lines 827-835
head -835 tensorrt_llm/_torch/models/modeling_speculative.py | tail -50

# Also check for forward_draft method and its return type
rg -n "def forward_draft" tensorrt_llm/_torch/models/modeling_speculative.py

# Find where this code is called in model_engine.py
rg -n "self.draft_model" tensorrt_llm/_torch/pyexecutor/model_engine.py -B 5 -A 15

# Get the forward_draft method signature and implementation
sed -n '846,900p' tensorrt_llm/_torch/models/modeling_speculative.py

# Find where SpeculativeHeadWrapper.forward is called
rg -n "\.forward(" tensorrt_llm/_torch/pyexecutor/model_engine.py -B 2 -A 5

# Search for usage of the return value from the model's forward/forward_draft
rg -n "forward_draft\|SpeculativeHeadWrapper" tensorrt_llm/_torch/pyexecutor/

# Fix the regex and search for forward calls
rg -n "\.forward\(" tensorrt_llm/_torch/pyexecutor/model_engine.py -B 2 -A 5

# Search for SpeculativeHeadWrapper usage
rg -n "SpeculativeHeadWrapper" tensorrt_llm/_torch/

# Look at the full forward method to understand return type
sed -n '800,850p' tensorrt_llm/_torch/models/modeling_speculative.py

# Search more broadly for how forward is used and what happens to its return value
rg -n "model.forward\|\.forward\(\)" tensorrt_llm/_torch/pyexecutor/ -A 3 | head -100

# Check what spec_worker returns
rg -n "class.*SpecWorker\|def __call__" tensorrt_llm/_torch/ | grep -i spec

# Look for spec_worker definition and return type
fd "spec_worker" tensorrt_llm/_torch/ --type f

# Let's check what self.spec_worker is - search for its definition/assignment
rg -n "self\.spec_worker\s*=" tensorrt_llm/_torch/models/modeling_speculative.py -B 2 -A 2

# Search for get_spec_worker function
rg -n "def get_spec_worker\|from.*spec_worker" tensorrt_llm/_torch/

# Look at the broader file structure to understand the class
head -50 tensorrt_llm/_torch/models/modeling_speculative.py

# Find get_spec_worker definition
rg -n "def get_spec_worker" tensorrt_llm/ -A 30

# Look for speculative module files
fd "speculative" tensorrt_llm/_torch/ --type f

# Check the return type annotation in forward method more carefully
sed -n '814,846p' tensorrt_llm/_torch/models/modeling_speculative.py

# Check what the worker classes return by looking at their __call__ methods
rg -n "class MTPWorker\|class MTPEagleWorker\|class Eagle3OneModelWorker" tensorrt_llm/_torch/speculative/ -A 50 | head -150

# Search for __call__ method in speculative worker classes
rg -n "def __call__" tensorrt_llm/_torch/speculative/ -A 15

# Look at full forward method and its return type annotation
sed -n '807,851p' tensorrt_llm/_torch/models/modeling_speculative.py

# Check class definition and where forward is called
rg -n "class SpeculativeHeadWrapper" tensorrt_llm/_torch/models/modeling_speculative.py -A 5

# Look at how the model output is used after calling forward
rg -n "SpeculativeHeadWrapper\|speculative_head_wrapper" tensorrt_llm/_torch/pyexecutor/model_engine.py -B 5 -A 10

# Check what modules/layers use SpeculativeHeadWrapper
rg -n "SpeculativeHeadWrapper" tensorrt_llm/_torch/ -B 2 -A 2

# Search for the class name from the file context
rg -n "class.*Wrapper" tensorrt_llm/_torch/models/modeling_speculative.py

# Look at imports and overall file structure
head -100 tensorrt_llm/_torch/models/modeling_speculative.py | tail -50

# Search for where speculative models are instantiated
rg -n "modeling_speculative\|from.*modeling_speculative" tensorrt_llm/_torch/ -l

# Find the class that contains this forward method - search backwards from line 807
sed -n '700,810p' tensorrt_llm/_torch/models/modeling_speculative.py | grep "^class"

# Let's get a broader view of the file structure
rg -n "^class " tensorrt_llm/_torch/models/modeling_speculative.py

# Look at the entire context around the forward method to find the class
sed -n '750,815p' tensorrt_llm/_torch/models/modeling_speculative.py

# Now search for where SpecDecOneEngineForCausalLM is used and how the forward result is handled
rg -n "SpecDecOneEngineForCausalLM" tensorrt_llm/_torch/ -B 2 -A 5

# Look for the model_engine forward method to see how it uses the model output
rg -n "def forward\(" tensorrt_llm/_torch/pyexecutor/model_engine.py | head -20

# Check line around 3297-3300 that we saw earlier which calls model.forward
sed -n '3290,3310p' tensorrt_llm/_torch/pyexecutor/model_engine.py

# Check what spec_worker types return
rg -n "class.*Worker.*:" tensorrt_llm/_torch/speculative/ -A 30 | grep -A 25 "class MTPWorker\|class Eagle3OneModelWorker"

# Look for the __call__ method in worker classes
fd "speculative" tensorrt_llm/_torch/ --type f -name "*.py"

# Check worker implementations
head -300 tensorrt_llm/_torch/speculative/worker.py | tail -200

# Find the return statement in MTPWorker.forward
rg -n "return" tensorrt_llm/_torch/speculative/mtp.py | grep -A 2 "MTPWorker\|def forward" | head -30

# Get the full MTPWorker.forward method
sed -n '362,550p' tensorrt_llm/_torch/speculative/mtp.py | tail -100

# Get the full Eagle3OneModelWorker.forward method
sed -n '373,600p' tensorrt_llm/_torch/speculative/eagle3.py | tail -150

# Check if there's any type casting or handling of the return value
sed -n '3295,3320p' tensorrt_llm/_torch/pyexecutor/model_engine.py

# Look at MTPWorker.forward return statement more carefully
rg -n "return {" tensorrt_llm/_torch/speculative/mtp.py | head -5

# Get exact return statement from MTPWorker
sed -n '500,530p' tensorrt_llm/_torch/speculative/mtp.py

# Find the actual return statement in MTPWorker.forward
rg -n "return \{" tensorrt_llm/_torch/speculative/mtp.py -A 10

# Look at the end of MTPWorker.forward method
tail -200 tensorrt_llm/_torch/speculative/mtp.py | head -100

# Check what model_forward returns in model_engine
sed -n '3310,3350p' tensorrt_llm/_torch/pyexecutor/model_engine.py

# Get the exact method signature
sed -n '807,820p' tensorrt_llm/_torch/models/modeling_speculative.py
```
Fix return type annotation.
The method's return type is annotated as torch.Tensor, but it can actually return either a torch.Tensor (when is_capturing and not self.enable_cuda_graph_for_draft_model) or a dict (from self.spec_worker). The spec_worker returns a dict with keys like logits, new_tokens, new_tokens_lens, etc., not a tensor. While PyTorchModelEngine._forward_step correctly handles both types using isinstance(outputs, dict), the method's return type annotation should be updated to reflect the actual return types (e.g., Union[torch.Tensor, Dict[str, Any]]).
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_speculative.py around lines 827 to 835,
the method is annotated to return torch.Tensor but actually may return either a
torch.Tensor or a dict from the spec_worker; update the return type annotation
to Union[torch.Tensor, Dict[str, Any]] (or typing.Any as appropriate),
add/import Union, Dict, Any from typing at the top of the file if not present,
and adjust any related type comments or stubs so the signature matches both
possible return types.
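A sketch of the suggested signature change (illustrative only; the parameter list is abbreviated and assumed, not copied from the PR):

```python
from typing import Any, Dict, Union

import torch


def forward(self, input_ids, position_ids, attn_metadata, spec_metadata,
            **kwargs) -> Union[torch.Tensor, Dict[str, Any]]:
    """Return hidden states when the draft loop is skipped during graph
    capture, otherwise return the spec-worker output dict."""
    ...
```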
```python
        ) or self.model_is_wrapped
        self.max_draft_len = spec_config.max_draft_len
        self.max_total_draft_tokens = spec_config.max_total_draft_tokens
        self.enable_cuda_graph_for_draft_model = spec_config.enable_cuda_graph_for_draft_model
```
Guard enable_cuda_graph_for_draft_model and forward_draft usage for non-speculative executors
Right now enable_cuda_graph_for_draft_model is only assigned when spec_config is not None, but the CUDA-graph replay path reads it unconditionally and also assumes inputs['spec_metadata'] exists. For non-speculative executors (no spec_config) that still use CUDA graphs, this can lead to:
- `AttributeError` on `self.enable_cuda_graph_for_draft_model`.
- Or, once the attribute is initialized, an invalid attempt to call `model.forward_draft` without `spec_metadata` and on models that don't implement that method.
You likely only intend to run forward_draft in speculative mode when the new flag is False. Suggest initializing the flag for all cases and gating the forward_draft call on self.enable_spec_decode as well.
Proposed fix: initialize flag safely and gate the `forward_draft` call
```diff
@@
- self.llm_args = llm_args
- self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
- self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
+ self.llm_args = llm_args
+ self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
+ self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
@@
- self.spec_config = spec_config
- self.is_spec_decode = spec_config is not None
+ self.spec_config = spec_config
+ self.is_spec_decode = spec_config is not None
+ # Default to True so non-speculative executors never take the draft-only path.
+ self.enable_cuda_graph_for_draft_model = (
+ spec_config.enable_cuda_graph_for_draft_model
+ if spec_config is not None else True
+ )
self.sparse_attention_config = None if is_draft_model else llm_args.sparse_attention_config
self.enable_spec_decode = self.is_spec_decode
self.is_draft_model = is_draft_model
@@
- else:
- with MoeLoadBalancerIterContext(moe_load_balancer):
- outputs = self.cuda_graph_runner.replay(key, inputs)
- if not self.enable_cuda_graph_for_draft_model:
- outputs = self.model.forward_draft(
- outputs, inputs['input_ids'],
- inputs['position_ids'],
- inputs['attn_metadata'],
- inputs['spec_metadata'])
+ else:
+ with MoeLoadBalancerIterContext(moe_load_balancer):
+ outputs = self.cuda_graph_runner.replay(key, inputs)
+ # When speculative decoding is enabled but we opted out of
+ # capturing the draft loop in the CUDA graph, run the
+ # draft-only pass after replay.
+ if (self.enable_spec_decode
+ and not self.enable_cuda_graph_for_draft_model):
+ outputs = self.model.forward_draft(
+ outputs,
+ inputs['input_ids'],
+ inputs['position_ids'],
+ inputs['attn_metadata'],
+ inputs['spec_metadata'],
+            )
```

This keeps non-speculative flows and non-Eagle3 models on the existing path while enabling the new “draft outside CUDA graph” behavior only where spec_config and forward_draft are defined.
Also applies to: 3269-3274
PR_Github #29690 [ run ] triggered by Bot. Commit:

PR_Github #29690 [ run ] completed with state
Description
Previously, the entire draft forward step was wrapped inside the CUDA graph capture in `modeling_speculative.py`. This is fast but not flexible enough for additional sampling, CPU operations, etc. Now, a flag is introduced to move the draft loop out of CUDA graph capture.

After this PR, with `enable_cuda_graph_for_draft_model=False`, the draft loops are not captured but are still queued as kernel runs.

[screenshot]

Before this PR, the drafting loop was included in the CUDA graph capture.

[screenshot]
The feature is not enabled by default to prevent perf loss.
Guided decoder functionality is being fixed.
Guided decoder perf is being fixed.
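For illustration, a hedged sketch of how the flag might be set through the LLM API. Everything other than `enable_cuda_graph_for_draft_model` and `eagle3_one_model` (model paths, `max_draft_len` value, and other arguments) is assumed based on existing Eagle3 usage, not part of this PR:

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig

# Opt out of capturing the draft loop in the CUDA graph; per the new
# validation, this is only allowed together with eagle3_one_model=True.
spec_config = EagleDecodingConfig(
    max_draft_len=3,                                     # assumed value for illustration
    speculative_model_dir="<eagle3-draft-checkpoint>",   # placeholder path
    eagle3_one_model=True,
    enable_cuda_graph_for_draft_model=False,
)

llm = LLM(
    model="<target-model-checkpoint>",                   # placeholder path
    speculative_config=spec_config,
)
```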

Summary by CodeRabbit
New Features
- `enable_cuda_graph_for_draft_model` defaults to `True`, with validation ensuring compatibility with the Eagle3 one-model configuration.

Tests
Test Coverage
Tried several combinations; all passed the accuracy tests. Acceptance length was hand-checked.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`

Provide a user friendly way for developers to interact with a Jenkins server.

Run `/bot [-h|--help]` to print this help message.

See details below for each supported subcommand.

Details

`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]`

Launch build/test pipelines. All previously running jobs will be killed.

- `--reuse-test (optional)pipeline-id` (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL) : Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- `--detailed-log` (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md` and the `scripts/test_to_stage_mapping.py` helper.

kill

`kill`

Kill all running builds associated with pull request.

skip

`skip --comment COMMENT`

Skip testing for latest commit on pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

`reuse-pipeline`

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.