feat & fix: support GQA/MQA and decode-phase attention via IAttentionLayer; add comprehensive HLO-level tests; fix bugs #4246

Open

zewenli98 wants to merge 5 commits into main from evanli/hlo-attention-tests

Conversation

zewenli98 (Collaborator) commented May 11, 2026

Description

Extends the TRT attention converter to support GQA/MQA and decode-phase attention, and adds a comprehensive HLO-level test suite.

Converter changes (aten_ops_converters.py, force_causal_efficient_attention.py):

  • Lifts the enable_gqa=True rejection in all three SDPA validators (scaled_dot_product_attention, flash, efficient). IAttentionLayer natively handles GQA/MQA, so the validator now verifies Hq % Hkv == 0 instead of blocking (see the sketch after this list).
  • Relaxes the shape-equality check to allow decode-phase attention (seq_q != seq_k) and GQA head-count mismatches, while still rejecting incompatible batch/head-dim shapes.
  • Adds attn_bias_is_causal parameter to DispatchTestCase.run_test to control whether the force_causal_efficient_attention lowering pass strips attn_bias before reaching the converter.
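
A minimal sketch of the relaxed shape/head-count check described above. This is illustrative only; the function and variable names are hypothetical and not the actual validator code in this PR:

# Hypothetical helper: accept MHA/GQA/MQA and decode-phase shapes,
# reject mismatched batch or head_dim.
def sdpa_shapes_supported(q_shape, k_shape, v_shape) -> bool:
    # Expect rank-4 [batch, heads, seq, head_dim] tensors.
    if not (len(q_shape) == len(k_shape) == len(v_shape) == 4):
        return False
    bq, hq, _seq_q, dq = q_shape  # seq_q may differ from seq_k (e.g. seq_q=1 in decode phase)
    bk, hkv, _seq_k, dk = k_shape
    bv, hv, _seq_v, dv = v_shape
    if (bk, hkv, dk) != (bv, hv, dv):
        return False  # k and v must agree on batch/heads/head_dim
    if bq != bk or dq != dk:
        return False  # batch and head_dim must match between q and k/v
    # MHA: Hq == Hkv; GQA/MQA: query heads must be an integer multiple of kv heads.
    return hq % hkv == 0

# Example: GQA decode-phase shapes (32 query heads grouped over 8 kv heads) are accepted.
# sdpa_shapes_supported((2, 32, 1, 64), (2, 8, 128, 64), (2, 8, 128, 64))  -> True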

New test suite (tests/py/dynamo/hlo/test_attention.py):

  • Covers all three SDPA kernel variants (standard, flash, efficient) across MHA/GQA/MQA, causal/non-causal, bool/float/broadcast masks, decode-phase (seq_q=1), non-power-of-2 head dims, LLM-realistic configs, and fp16/bf16/fp32 (a representative case is sketched after this list).
  • Known issues are documented inline (large causal sequences in TRT 10.x and TensorRT-RTX 1.4, fp32 GQA without decompose_attention in PyTorch 2.12.0).
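
As a rough illustration of the kind of case the suite exercises, here is a minimal decode-phase GQA module. This is a sketch only; the class name and shapes are illustrative and do not come from the actual test file, and the enable_gqa flag requires a PyTorch version that supports it:

import torch
import torch.nn.functional as F

class DecodePhaseGQA(torch.nn.Module):
    def forward(self, q, k, v):
        # q: [B, Hq, 1, D] (decode phase, seq_q=1); k/v: [B, Hkv, S, D] with Hq % Hkv == 0
        return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=True)

# Example shapes: 32 query heads grouped over 8 kv heads; the suite also covers fp16/bf16 variants.
q = torch.randn(2, 32, 1, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)
out = DecodePhaseGQA()(q, k, v)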

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that the relevant reviewers are notified

zewenli98 requested a review from narendasan May 11, 2026 07:25
zewenli98 self-assigned this May 11, 2026
meta-cla bot added the cla signed label May 11, 2026
github-actions bot added labels: component: tests, component: lowering, component: conversion, component: core, component: api [Python], component: dynamo (May 11, 2026)
# this method is only used in our converter test to infer the module output dtypes via dummy inference
# which is due to fx.symbolic_trace does not have the meta['val'] info in the node
# TODO: lan to remove this once our converter test is moved from fx.symbolic_trace to dynamo trace
def infer_module_output_dtypes_for_test(
Collaborator

Do we not need this anymore?

Collaborator Author

This function is a duplicate of the one at line 42.


zewenli98 force-pushed the evanli/hlo-attention-tests branch from 58b7110 to 430df82 on May 11, 2026 18:33
github-actions bot requested a review from cehongwang May 12, 2026 17:15
github-actions bot added the component: converters label May 12, 2026
zewenli98 force-pushed the evanli/hlo-attention-tests branch from b0a219b to db3e95a on May 12, 2026 19:26
Comment on lines 187 to +207
     scale_factor = impl.elementwise.div(
         ctx, target, source_ir, f"{name}_div_1_sqrt_q_dim", 1, sqrt_q_dim
-    )
-    if scale_factor.dtype != query.dtype:
+    )  # fp32
+
+    # For TRT version < 11.0, when seq_len (dim: -2) of k/v >= 512, IAttentionLayer with causal=True returns significantly mismatched results compared to torch.nn.functional.scaled_dot_product_attention
+    # NVBug: https://nvbugspro.nvidia.com/bug/6047232
+    if key.shape[-2] >= 512 and is_causal:
+        if is_tensorrt_version_supported("11.0"):
             scale_factor = cast_trt_tensor(
                 ctx,
                 scale_factor,
                 query.dtype,
                 name + "_cast_scale_factor",
                 target,
                 source_ir,
             )
+        else:
+            _LOGGER.warning(
+                "For TRT 10.x, when seq_len (dim: -2) of k/v >= 512, IAttentionLayer with causal=True returns significantly mismatched results compared to `torch.nn.functional.scaled_dot_product_attention` in FP16/BF16. Thus, we use FP32 for the scale factor. If you want to use the accurate dtype, please set `decompose_attention=True` or upgrade to TRT 11.0 or later."
+            )
Collaborator Author

Due to the TRT bug, I force the scale_factor here to be fp32 in TRT 10.x.

Comment on lines +9 to +14
TensorRT 10.x (resolved in TRT 11.0) and TensorRT-RTX-1.4:
For TensorRT 10.x, with large causal k/v sequences (seq >= 512, is_causal=True) in FP16/BF16,
IAttentionLayer produces ~80% element mismatch at long sequences. Thus, we use FP32 for
the scale factor. If you want to use the accurate dtype, please set `decompose_attention=True`
or upgrade to TRT 11.0 or later. TODO: @Evan to verify the version of TensorRT-RTX that
resolves this bug.
Collaborator Author

I'm confirming with the TRT team which version of TensorRT-RTX will resolve the bug.

Specific to TensorRT-RTX on Windows, I'm seeing RuntimeError: USE_FLASH_ATTENTION was not enabled for build, so I skipped the flash attention test suite on it (a possible skip pattern is sketched below).
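
A rough sketch of such a platform-specific skip. The names here are hypothetical and the actual guard in the test suite may differ:

import platform

import pytest

# Assumption: some project-level flag or helper reports whether the TensorRT-RTX backend is in use.
USING_TENSORRT_RTX = False  # placeholder for the real detection logic

skip_flash_on_rtx_windows = pytest.mark.skipif(
    USING_TENSORRT_RTX and platform.system() == "Windows",
    reason="TensorRT-RTX on Windows: USE_FLASH_ATTENTION was not enabled for build",
)

@skip_flash_on_rtx_windows
def test_flash_attention_case():
    ...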

Ideally, we should support all kinds of attention variants through IAttention, i.e., setting decompose_attention==False for all tests.

Collaborator

narendasan left a comment

LGTM
