feat & fix: support GQA/MQA and decode-phase attention via IAttentionLayer; add comprehensive HLO-level tests; fix bugs #4246
Open
zewenli98 wants to merge 5 commits into
Conversation
narendasan reviewed May 11, 2026
# this method is only used in our converter test to infer the module output dtypes via dummy inference
# which is due to fx.symbolic_trace does not have the meta['val'] info in the node
# TODO: lan to remove this once our converter test is moved from fx.symbolic_trace to dynamo trace
def infer_module_output_dtypes_for_test(
Collaborator
Do we not need this anymore?
Collaborator
Author
This function is a duplicate of the one at line 42.
58b7110 to 430df82
b0a219b to db3e95a
zewenli98 commented May 12, 2026
Comment on lines 187 to +207
scale_factor = impl.elementwise.div(
    ctx, target, source_ir, f"{name}_div_1_sqrt_q_dim", 1, sqrt_q_dim
)  # fp32

# For TRT version < 11.0, when seq_len (dim: -2) of k/v >= 512, IAttentionLayer with causal=True
# returns significantly mismatched results compared to torch.nn.functional.scaled_dot_product_attention
# NVBug: https://nvbugspro.nvidia.com/bug/6047232
if scale_factor.dtype != query.dtype:
    if key.shape[-2] >= 512 and is_causal:
        if is_tensorrt_version_supported("11.0"):
            scale_factor = cast_trt_tensor(
                ctx,
                scale_factor,
                query.dtype,
                name + "_cast_scale_factor",
                target,
                source_ir,
            )
        else:
            _LOGGER.warning(
                "For TRT 10.x, when seq_len (dim: -2) of k/v >= 512, IAttentionLayer with causal=True returns significantly mismatched results compared to `torch.nn.functional.scaled_dot_product_attention` in FP16/BF16. Thus, we use FP32 for the scale factor. If you want to use the accurate dtype, please set `decompose_attention=True` or upgrade to TRT 11.0 or later."
            )
Collaborator
Author
Due to the TRT bug, I force the scale_factor here to be fp32 in TRT 10.x.
zewenli98 commented May 13, 2026
Comment on lines +9 to +14
TensorRT 10.x (resolved in TRT 11.0) and TensorRT-RTX-1.4:
For TensorRT 10.x with large causal k/v sequences (seq >= 512, is_causal=True) in FP16/BF16,
IAttentionLayer produces ~80% element mismatch at long sequences. Thus, we use FP32 for
the scale factor. If you want to use the accurate dtype, please set `decompose_attention=True`
or upgrade to TRT 11.0 or later. TODO: @Evan to verify the version of TensorRT-RTX that
resolves this bug.
Collaborator
Author
I'm confirming with the TRT team which version of TensorRT-RTX will resolve the bug.
Specifically for TensorRT-RTX on Windows, I'm seeing `RuntimeError: USE_FLASH_ATTENTION was not enabled for build.`, so I skipped the flash attention test suite on it.
Ideally, we should support all kinds of attention variants through IAttention, i.e., setting `decompose_attention=False` for all tests.
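For reference, a minimal sketch of how a user could fall back to the decomposed path with that setting. The `decompose_attention` name comes from this PR, but how it is threaded through `torch_tensorrt.compile` is an assumption here, not something the PR spells out:

import torch
import torch_tensorrt


class SDPABlock(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)


q = torch.randn(2, 8, 1024, 64, dtype=torch.half, device="cuda")
k = torch.randn(2, 8, 1024, 64, dtype=torch.half, device="cuda")
v = torch.randn(2, 8, 1024, 64, dtype=torch.half, device="cuda")

# Assumption: decompose_attention is forwarded as a dynamo compiler setting.
# Setting it to True would skip IAttentionLayer and lower SDPA to matmul/softmax
# ops, sidestepping the TRT 10.x long-sequence causal mismatch.
trt_mod = torch_tensorrt.compile(
    SDPABlock().cuda().half(),
    ir="dynamo",
    inputs=[q, k, v],
    enabled_precisions={torch.half},
    decompose_attention=True,
)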
Description
Extends the TRT attention converter to support GQA/MQA and decode-phase attention, and adds a comprehensive HLO-level test suite.
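For context, a minimal eager-PyTorch sketch of the attention patterns the converter now has to cover; the shapes are illustrative only, and `enable_gqa` requires a recent PyTorch (2.5+):

import torch
import torch.nn.functional as F

B, S, D = 2, 128, 64
n_q_heads, n_kv_heads = 8, 2  # GQA: 8 query heads share 2 k/v heads (MQA would be n_kv_heads=1)

# Prefill-phase GQA: query carries more heads than key/value.
q = torch.randn(B, n_q_heads, S, D)
k = torch.randn(B, n_kv_heads, S, D)
v = torch.randn(B, n_kv_heads, S, D)
prefill_out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

# Decode-phase: a single new query token attends over the cached k/v sequence.
q_step = torch.randn(B, n_q_heads, 1, D)
decode_out = F.scaled_dot_product_attention(q_step, k, v, enable_gqa=True)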
Converter changes (aten_ops_converters.py, force_causal_efficient_attention.py):
New test suite (tests/py/dynamo/hlo/test_attention.py):
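As a rough illustration only, the sketch below shows the kind of eager-vs-TRT comparison such an HLO-level test might perform; the module structure, shapes, and tolerances are assumptions, not taken from the actual file:

import torch
import torch_tensorrt


class AttentionModule(torch.nn.Module):
    def __init__(self, is_causal: bool = True):
        super().__init__()
        self.is_causal = is_causal

    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=self.is_causal, enable_gqa=True
        )


def check_gqa_case(n_q_heads=8, n_kv_heads=2, seq=512, head_dim=64):
    # Compare the TRT-compiled module against eager SDPA for a GQA shape.
    q = torch.randn(1, n_q_heads, seq, head_dim, dtype=torch.half, device="cuda")
    k = torch.randn(1, n_kv_heads, seq, head_dim, dtype=torch.half, device="cuda")
    v = torch.randn(1, n_kv_heads, seq, head_dim, dtype=torch.half, device="cuda")
    mod = AttentionModule().cuda().half().eval()
    trt_mod = torch_tensorrt.compile(
        mod, ir="dynamo", inputs=[q, k, v], enabled_precisions={torch.half}
    )
    torch.testing.assert_close(trt_mod(q, k, v), mod(q, k, v), rtol=1e-2, atol=1e-2)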
Type of change
Checklist: