From 753c7f584cbd0e03596ff3c1a1363b6760272347 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Mon, 1 Jun 2026 01:39:54 -0400 Subject: [PATCH] Add Qwen3-Coder-Next contrib model (80B/3B hybrid DeltaNet+MoE) NxD Inference implementation for Qwen/Qwen3-Coder-Next, a hybrid Gated DeltaNet + GQA + Sparse MoE model (80B total, ~3B active/token). Key features: - Custom NKI kernels for DeltaNet recurrence and d=256 flash attention - Expert Parallelism support (EP=1,2,4,8 validated) - Shared expert EP scaling fix for correct world_group all-reduce - vLLM integration for serving - 77 tok/s at batch=1 on trn2.48xlarge (TP=8) - 100% top-1 accuracy match vs HF CPU reference (14/14 prompts) - Constant TPOT regardless of context length (DeltaNet O(1) generation) Architecture: 48 layers (36 DeltaNet + 12 GQA), 512 experts top-10, head_dim=256, partial RoPE (25%), max context 1024 tokens. Tested on: trn2.48xlarge, SDK 2.30, LNC=2 --- contrib/models/Qwen3-Coder-Next/README.md | 222 + .../models/Qwen3-Coder-Next/src/__init__.py | 26 + .../src/modeling_qwen35_moe.py | 3776 +++++++++++++++++ .../Qwen3-Coder-Next/src/nki_deltanet.py | 349 ++ .../src/nki_flash_attn_d256_pipe.py | 2258 ++++++++++ .../models/Qwen3-Coder-Next/test/__init__.py | 0 .../test/integration/__init__.py | 0 .../test/integration/test_model.py | 513 +++ .../Qwen3-Coder-Next/test/unit/__init__.py | 0 .../Qwen3-Coder-Next/vllm/register_model.py | 30 + .../vllm/start_vllm_server.sh | 71 + .../Qwen3-Coder-Next/vllm/test_vllm_client.py | 121 + 12 files changed, 7366 insertions(+) create mode 100644 contrib/models/Qwen3-Coder-Next/README.md create mode 100644 contrib/models/Qwen3-Coder-Next/src/__init__.py create mode 100644 contrib/models/Qwen3-Coder-Next/src/modeling_qwen35_moe.py create mode 100644 contrib/models/Qwen3-Coder-Next/src/nki_deltanet.py create mode 100644 contrib/models/Qwen3-Coder-Next/src/nki_flash_attn_d256_pipe.py create mode 100644 contrib/models/Qwen3-Coder-Next/test/__init__.py create mode 100644 
contrib/models/Qwen3-Coder-Next/test/integration/__init__.py create mode 100644 contrib/models/Qwen3-Coder-Next/test/integration/test_model.py create mode 100644 contrib/models/Qwen3-Coder-Next/test/unit/__init__.py create mode 100644 contrib/models/Qwen3-Coder-Next/vllm/register_model.py create mode 100755 contrib/models/Qwen3-Coder-Next/vllm/start_vllm_server.sh create mode 100644 contrib/models/Qwen3-Coder-Next/vllm/test_vllm_client.py diff --git a/contrib/models/Qwen3-Coder-Next/README.md b/contrib/models/Qwen3-Coder-Next/README.md new file mode 100644 index 00000000..7b6644db --- /dev/null +++ b/contrib/models/Qwen3-Coder-Next/README.md @@ -0,0 +1,222 @@ +# Contrib Model: Qwen3-Coder-Next + +Optimized NxD Inference implementation for Qwen3-Coder-Next, a hybrid Gated DeltaNet + GQA + Sparse MoE model with 80B total parameters and ~3B active per token, running on AWS Trainium2. + +## Model Information + +- **HuggingFace ID:** [`Qwen/Qwen3-Coder-Next`](https://huggingface.co/Qwen/Qwen3-Coder-Next) +- **Model Type:** Hybrid DeltaNet (linear recurrent) + GQA + Sparse MoE decoder +- **Parameters:** 80B total, ~3B active per token (BF16) +- **Architecture:** 48 layers (36 DeltaNet + 12 GQA), 512 experts top-10, head_dim=256, partial RoPE (25%) +- **License:** Apache 2.0 +- **Maintainer:** Jim Burtoft + +## Validation Results + +**Validated:** 2026-05-29 +**Instance:** trn2.48xlarge (TP=8, LNC=2) +**SDK:** Neuron SDK 2.30 (neuronx-cc 2.25.3371, neuronx-distributed-inference 0.10.17970) + +### Benchmark Results + +| Metric | Value | +|--------|-------| +| **Throughput** | **77 tok/s** | +| TPOT (median) | 13.0 ms | +| TPOT (p99) | 13.3 ms | +| TTFT @ 32 tokens | 245 ms | +| TTFT @ 128 tokens | 1,235 ms | +| TTFT @ 256 tokens | 1,939 ms | +| TTFT @ 512 tokens | 3,471 ms | +| TTFT @ 1024 tokens | 7,091 ms | + +Configuration: batch_size=1, greedy decoding, single CTE bucket. 
+ +### Accuracy Validation + +| Metric | Value | +|--------|-------| +| Top-1 token match rate | 100% (14/14 prompts) | +| Cosine similarity (logit vectors) | 0.9998 | +| Max logit difference | 0.38 | + +Validated against HuggingFace BF16 CPU reference using greedy decoding with teacher forcing. + +## Usage + +### Prerequisites + +```bash +# Activate NxDI environment on trn2.48xlarge +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +# Download model weights (~149 GB) +pip install huggingface_hub[cli] +huggingface-cli download Qwen/Qwen3-Coder-Next --local-dir /mnt/models/Qwen3-Coder-Next/ +``` + +### Compile and Run + +```python +import os, sys, torch +from transformers import AutoTokenizer, AutoConfig +from neuronx_distributed_inference.models.config import MoENeuronConfig + +sys.path.insert(0, '/path/to/contrib/models/Qwen3-Coder-Next/src') +os.environ['NEURON_CC_FLAGS'] = '--auto-cast matmult --auto-cast-type bf16' + +from modeling_qwen35_moe import NeuronQwen35MoeForCausalLM, Qwen35MoeInferenceConfig + +model_path = '/mnt/models/Qwen3-Coder-Next' + +def make_load_config(model_path): + def _load_config(config_self): + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + for key, value in hf_config.to_dict().items(): + if not key.startswith('_') and key != 'transformers_version': + setattr(config_self, key, value) + return _load_config + +neuron_config = MoENeuronConfig( + tp_degree=8, + max_batch_size=1, + max_context_length=1024, + max_new_tokens=128, + max_length=1152, + torch_dtype=torch.bfloat16, + fused_qkv=True, + moe_tp_degree=8, + moe_ep_degree=1, + enable_bucketing=True, + context_encoding_buckets=[32], + blockwise_matmul_config={ + 'block_size': 128, + 'use_shard_on_block_dynamic_while': True, + 'block_sharding_strategy': 'PING_PONG', + }, +) + +inference_config = Qwen35MoeInferenceConfig( + neuron_config=neuron_config, + load_config=make_load_config(model_path), +) + +# Compile (first time only) +model = 
NeuronQwen35MoeForCausalLM(model_path, inference_config) +model.compile(compiled_model_path='/mnt/compiled_qwen3') + +# Load and generate +model.load('/mnt/compiled_qwen3') +model.reset() + +tokenizer = AutoTokenizer.from_pretrained(model_path) +input_ids = tokenizer("def quicksort(arr):\n", return_tensors='pt').input_ids +n = input_ids.shape[1] + +with torch.no_grad(): + out = model.forward( + input_ids=input_ids, + attention_mask=torch.ones(1, n, dtype=torch.int32), + position_ids=torch.arange(n, dtype=torch.long).unsqueeze(0), + seq_ids=torch.zeros(1, dtype=torch.long), + ) + logits = out[0][0][-1] if out[0][0].dim() == 2 else out[0][0] + print(tokenizer.decode(logits.argmax().item())) +``` + +## Expert Parallelism (EP) + +This model supports Expert Parallelism for distributing the 512 experts across multiple EP ranks. EP reduces per-rank HBM usage for expert weights and enables full-chip utilization on trn2.48xlarge. + +### Validated Configurations + +| EP | TP | World Size | Cores | Status | +|----|----|-----------:|------:|--------| +| 1 | 8 | 8 | 8 | Baseline | +| 2 | 8 | 16 | 16 | Validated | +| 4 | 8 | 32 | 32 | Validated | +| 8 | 8 | 64 | 64 | Validated (full chip) | + +### EP Configuration + +```python +neuron_config = MoENeuronConfig( + tp_degree=8, + ep_degree=4, # Expert Parallelism degree + moe_ep_degree=4, # Must match ep_degree + moe_tp_degree=8, # Must match tp_degree + # ... other config +) +``` + +### EP Implementation Notes + +- **Shared expert scaling**: With EP > 1, the framework's world_group all-reduce sums the shared expert output `ep_degree` times (since it's identical across EP ranks). The model compensates by dividing the shared expert output by `ep_degree` in the CTE path. +- **CTE dispatch**: The `ExpertMLPsV2.forward` must be patched to use `forward_blockwise` for CTE (the default `forward_selective_loading` does not support EP). A monkeypatch is provided in the test scripts. 
+- **ctx=32 with EP=8**: Compilation fails due to NKI DeltaNet kernel assertion. Use ctx >= 128 with EP=8. + +## Compatibility Matrix + +| Instance | SDK 2.30 | +|----------|----------| +| trn2.48xlarge (TP=8, EP=1-8) | VALIDATED | +| trn2.3xlarge (TP=4) | NOT SUPPORTED (HBM OOM) | + +## Example Checkpoints + +* [`Qwen/Qwen3-Coder-Next`](https://huggingface.co/Qwen/Qwen3-Coder-Next) + +## Testing Instructions + +```bash +# On trn2.48xlarge with SDK 2.30 DLAMI +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +cd contrib/models/Qwen3-Coder-Next/ + +# Run integration tests +MODEL_PATH=/mnt/models/Qwen3-Coder-Next \ +COMPILED_PATH=/mnt/compiled_qwen3_test/ \ +pytest test/integration/test_model.py -v +``` + +## Architecture Details + +| Parameter | Value | +|-----------|-------| +| Total parameters | 80B | +| Active parameters/token | ~3B | +| Layers | 48 (36 DeltaNet + 12 GQA) | +| Hidden size | 2048 | +| Experts per layer | 512, top-10 | +| Expert intermediate size | 512 | +| Attention heads (Q/KV) | 16/2, head_dim=256 | +| DeltaNet heads (QK/V) | 16/32, head_dim=128 | +| Partial RoPE | 25% (64 of 256 dims) | +| Max context (Neuron) | 1024 tokens | + +### Key Properties + +1. **Constant TPOT**: Token generation latency is O(1) regardless of context length (DeltaNet recurrent state, no growing KV cache for 36/48 layers) +2. **Linear TTFT**: Prefill scales linearly with input tokens (~6.9 ms/token) +3. **Single-bucket optimization**: Use one CTE bucket per deployment for best prefill latency + +### Custom NKI Kernels + +| Kernel | Purpose | File | +|--------|---------|------| +| DeltaNet Recurrent | Token generation for linear attention layers | `nki_deltanet.py` | +| Flash Attention d=256 | Context encoding for GQA layers (seq >= 512) | `nki_flash_attn_d256_pipe.py` | + +## vLLM Integration + +This model supports vLLM serving via the `vllm/` directory. See `vllm/start_vllm_server.sh` for usage. + +## Known Issues + +1. 
**Max context: 1024 tokens** — Model weights consume 20.8 GB per NeuronCore pair (LNC=2), leaving ~2.6 GB for scratchpad. Context lengths > 1024 exceed available HBM. INT8 quantization would unlock longer contexts. +2. **TP=4 not supported** — Per-rank expert weights (~37 GB) exceed 24 GB HBM per core at TP=4. +3. **TP=16 not supported** — NKI DeltaNet kernel requires `linear_value_head_dim >= 16` per rank (128/16=8 is too small). +4. **DeltaNet state reset** — Must call `model.reset()` between independent prompts to clear recurrent state. +5. **NKI deprecation warnings** on import (cosmetic, from blockwise_mm internals in neuronx-distributed). +6. **EP=8 requires ctx >= 128** — NKI DeltaNet kernel fails at ctx=32 with EP=8 ("Out-of-bound access... index range [0, 127] exceed dimension size of 33"). Use context_encoding_buckets with minimum size 128 for EP=8. +7. **position_ids for padded CTE inputs** — When manually padding inputs for CTE, padding positions must have `position_id = 0` (not incrementing values). The framework uses `torch.max(position_ids)` to find the last real token. Incorrect position_ids will cause the model to output `<|endoftext|>`. 
diff --git a/contrib/models/Qwen3-Coder-Next/src/__init__.py b/contrib/models/Qwen3-Coder-Next/src/__init__.py new file mode 100644 index 00000000..61b35c01 --- /dev/null +++ b/contrib/models/Qwen3-Coder-Next/src/__init__.py @@ -0,0 +1,26 @@ +# Qwen3-Coder-Next NxDI contrib model +# Hybrid DeltaNet + GQA + Sparse MoE (80B total / 3B active per token) + +from .modeling_qwen35_moe import ( + Qwen35MoeInferenceConfig, + NeuronQwen35MoeForCausalLM, + NeuronQwen35MoeModel, + NeuronGatedDeltaNet, + NeuronQwen35Attention, + NeuronQwen35DecoderLayer, + SigmoidGatedSharedExperts, + Qwen35DecoderModelInstance, + Qwen35ModelWrapper, +) + +__all__ = [ + "Qwen35MoeInferenceConfig", + "NeuronQwen35MoeForCausalLM", + "NeuronQwen35MoeModel", + "NeuronGatedDeltaNet", + "NeuronQwen35Attention", + "NeuronQwen35DecoderLayer", + "SigmoidGatedSharedExperts", + "Qwen35DecoderModelInstance", + "Qwen35ModelWrapper", +] diff --git a/contrib/models/Qwen3-Coder-Next/src/modeling_qwen35_moe.py b/contrib/models/Qwen3-Coder-Next/src/modeling_qwen35_moe.py new file mode 100644 index 00000000..836cdc74 --- /dev/null +++ b/contrib/models/Qwen3-Coder-Next/src/modeling_qwen35_moe.py @@ -0,0 +1,3776 @@ +""" +NxDI contrib: Qwen3.5-35B-A3B (qwen3_5_moe / qwen3_next) + +Hybrid DeltaNet + Standard Attention + MoE architecture. +Based on NxDI Qwen3-MoE with custom DeltaNet layers. 
+ +Qwen3-Coder-Next: 36 of 48 layers use Gated DeltaNet (linear recurrent attention), 12 use standard GQA with KV cache + output gate, +and all 48 layers use sparse MoE (512 experts, top-10 + shared expert with sigmoid gate). +Qwen3.5-35B-A3B: 30 of 40 layers use Gated DeltaNet, 10 use standard GQA, and all 40 use sparse MoE (256 experts, top-8). + +Architecture details: +- DeltaNet layers: separate in_proj_{qkv, z, a, b}, causal conv1d on QKV, gated delta rule +- Attention layers: q_proj doubled (Q + gate), partial RoPE (25% of head_dim), sigmoid output gate +- MoE: pre-fused expert weights, shared expert with sigmoid gate +- KV cache: NxDI KVCacheManager for attention layers; DeltaNet layers store recurrent+conv + state as nn.Parameter buffers and return dummy KV tuples +""" + +import gc +import math +import logging +import os +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.attention.gqa import GQA +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm + +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, +) +from neuronx_distributed.utils import cpu_mode +from torch_neuronx.xla_impl.ops import nki_jit +from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeRMSNorm + +from nki_deltanet import deltanet_recurrent_fwd as _deltanet_nki_kernel +from nki_deltanet import deltanet_recurrent_fwd_state as _deltanet_nki_kernel_state +from nki_flash_attn_d256_pipe import flash_attn_d256_pipe as _flash_attn_d256_kernel_raw + +# NKI 0.3.0: @nki.jit returns nki.Kernel which auto-detects framework. +# No need for manual PyTorchXLAKernel wrapping (removed in nki 0.3.0). 
+_flash_attn_d256_kernel = _flash_attn_d256_kernel_raw + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + MoENeuronConfig, + SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP, + MOE_TKG_MK_INTERMEDIATE_PER_TP, +) +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, + DecoderModelInstance, + ModelWrapper, +) +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, + FlashAttentionStrategy, +) +from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding +from neuronx_distributed_inference.modules.moe_v2 import initialize_moe_module +from neuronx_distributed_inference.models.layer_boundary_marker import ( + ModuleMarkerEndWrapper, + ModuleMarkerStartWrapper, +) + +logger = logging.getLogger(__name__) + +_flash_fwd_call = nki_jit()(attention_isa_kernel) + + +def _patch_fused_tkg_for_fp32_router(): + """Patch MoEFusedTKG kernel to use ISA router fallback for float32 routing. + + The fused MoE TKG NKI kernel's router_topk_kernel_nki asserts that input + and weight dtypes match. When router_config.dtype=float32 (for accuracy), + the hidden states are float32 but router weights are bfloat16, causing + an assertion error. The ISA router fallback handles mixed dtypes correctly. + + Must be called before model.compile(). 
+ """ + try: + import neuronx_distributed.modules.moe.moe_fused_tkg as fused_tkg_mod + + original_kernel = fused_tkg_mod._moe_token_gen_selective_load_kernel_nki_call + if original_kernel is None: + logger.warning( + "Fused TKG selective load kernel not available, skipping patch" + ) + return + + class _PatchedKernelCall: + """Wrapper that injects use_router_topk_nki_kernel=False.""" + + def __init__(self, original): + self._original = original + + def __getitem__(self, grid): + original_grid_call = self._original[grid] + + def patched_call(*args, **kwargs): + kwargs["use_router_topk_nki_kernel"] = False + return original_grid_call(*args, **kwargs) + + return patched_call + + fused_tkg_mod._moe_token_gen_selective_load_kernel_nki_call = ( + _PatchedKernelCall(original_kernel) + ) + + # Also patch the forward-all-experts kernel + original_all = fused_tkg_mod._moe_tkg_forward_all_experts_nki_call + if original_all is not None: + fused_tkg_mod._moe_tkg_forward_all_experts_nki_call = _PatchedKernelCall( + original_all + ) + + logger.info("Patched MoEFusedTKG for float32 router (ISA router fallback).") + except ImportError: + logger.info("moe_fused_tkg module not available, skipping patch") + except Exception as e: + logger.warning("Failed to patch MoEFusedTKG: %s", e) + + +GQA_SHARDING_STRATEGY = GQA.REPLICATE_TO_TP_DEGREE + + +# ============================================================ +# Newton-Raphson Refined RMSNorm (Task 18) +# ============================================================ +# Hardware rsqrt has systematic negative bias (~-7e-6 mean). +# Over 80+ RMSNorm applications across 40 layers, this compounds. +# One Newton refinement step: y' = y * (3 - x*y^2) / 2 +# reduces per-application error from 3.5e-4 to 1.9e-6. + +# Toggle: set to True to use Newton-refined rsqrt in RMSNorm +USE_NEWTON_RMSNORM = False + + +class NewtonRMSNorm(nn.Module): + """RMSNorm with Newton-Raphson refined rsqrt for improved numerical accuracy. 
+ + Drop-in replacement for CustomRMSNorm. Uses pure PyTorch ops so the + Neuron compiler traces it directly (no opaque custom HLO call). + """ + + def __init__(self, hidden_size=None, eps=1e-6): + super().__init__() + self.weight = None + if hidden_size is not None: + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.hidden_size = hidden_size + self.variance_epsilon = eps + + def forward(self, hidden_states): + original_dtype = hidden_states.dtype + x = hidden_states.to(torch.float32) + # Variance: mean(x^2) along last dim + variance = x.pow(2).mean(-1, keepdim=True) + # Initial hardware rsqrt estimate + y = torch.rsqrt(variance + self.variance_epsilon) + # Newton-Raphson refinement: y' = y * (3 - (var+eps) * y^2) / 2 + y = y * (3.0 - (variance + self.variance_epsilon) * y * y) * 0.5 + # Apply normalization and weight + result = x * y + if self.weight is not None: + result = result * self.weight.float() + return result.to(original_dtype) + + +def get_rmsnorm_cls(): + if cpu_mode(): + return Qwen3MoeRMSNorm + return NewtonRMSNorm if USE_NEWTON_RMSNORM else CustomRMSNorm + + +def l2norm(x, dim=-1, eps=1e-6): + return F.normalize(x, p=2, dim=dim, eps=eps) + + +# ============================================================ +# Gated DeltaNet Module (Linear Recurrent Attention) +# ============================================================ + + +class NeuronGatedDeltaNet(nn.Module): + """ + Gated DeltaNet linear attention for Neuron. + + Replaces standard attention for 30 of 40 layers. + Uses a chunk-based linear recurrence instead of KV cache. + + V1 Design (stateless -- compiles but loses state between CTE and TKG): + - CTE: chunk forward computes correct output for the prefill sequence. + - TKG: recurrent step with ZERO initial state (no carry-over from CTE). + - DeltaNet layers return dummy (K, V) tuples so KVCacheManager can process them. + - No in-place buffer mutations (XLA trace safe). 
+ + V2 TODO: Use input_output_aliases or repurpose KV cache slots to carry + recurrent state and conv state between CTE and TKG. + + HF weight layout: + - in_proj_qkv.weight: (key_dim*2 + value_dim, hidden_size) = (8192, 2048) + - in_proj_z.weight: (value_dim, hidden_size) = (4096, 2048) + - in_proj_a.weight: (num_v_heads, hidden_size) = (32, 2048) + - in_proj_b.weight: (num_v_heads, hidden_size) = (32, 2048) + - conv1d.weight: (conv_dim, 1, conv_kernel_size) = (8192, 1, 4) + - A_log: (num_v_heads,) = (32,) + - dt_bias: (num_v_heads,) = (32,) + - norm.weight: (head_v_dim,) = (128,) + - out_proj.weight: (hidden_size, value_dim) = (2048, 4096) + """ + + def __init__(self, config, layer_idx: int): + super().__init__() + tc = config + + self.hidden_size = tc.hidden_size # 2048 + self.num_v_heads = tc.linear_num_value_heads # 32 + self.num_k_heads = tc.linear_num_key_heads # 16 + self.head_k_dim = tc.linear_key_head_dim # 128 + self.head_v_dim = tc.linear_value_head_dim # 128 + self.key_dim = self.head_k_dim * self.num_k_heads # 2048 + self.value_dim = self.head_v_dim * self.num_v_heads # 4096 + self.conv_kernel_size = tc.linear_conv_kernel_dim # 4 + self.layer_idx = layer_idx + self.rms_norm_eps = tc.rms_norm_eps + + # KV cache dummy shape info (for returning proper-shaped zeros) + # Must match KVCacheManager's per-rank shape: (B, num_kv_heads_per_rank, seq_len, head_dim) + # With REPLICATE_TO_TP_DEGREE: raw KV heads (2) replicated to tp_degree (4), then /tp = 1 per rank + self.head_dim = tc.head_dim # 256 + tp_degree = tc.neuron_config.tp_degree + raw_kv_heads = tc.num_key_value_heads + # Replicate KV heads to tp_degree if fewer, then divide + if raw_kv_heads < tp_degree: + replicated_kv_heads = tp_degree # REPLICATE_TO_TP_DEGREE strategy + else: + replicated_kv_heads = raw_kv_heads + self.kv_heads_per_rank = replicated_kv_heads // tp_degree + + # Conv1d on concatenated QKV (NOT Z) + self.conv_dim = self.key_dim * 2 + self.value_dim # 8192 + self.conv1d = nn.Conv1d( + 
in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=self.conv_kernel_size, + groups=self.conv_dim, + padding=self.conv_kernel_size - 1, + ) + + # Input projections -- separate to match HF weight layout + self.in_proj_qkv = nn.Linear( + self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False + ) + self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False) + self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False) + + # Decay parameters + self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads)) + self.A_log = nn.Parameter(torch.zeros(self.num_v_heads)) + + # Output norm and projection + # Use standard RMSNorm (not CustomRMSNorm) since DeltaNet is custom code + # and we need it to work in both CPU mode and Neuron tracing. + # The Qwen3MoeRMSNorm is a plain PyTorch RMSNorm that works everywhere. + self.norm = Qwen3MoeRMSNorm(self.head_v_dim, eps=self.rms_norm_eps) + self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + # ---- State buffers for CTE -> TKG carry-over ---- + # These are nn.Parameter(requires_grad=False) so they participate in + # input_output_aliases, allowing the XLA runtime to alias the output + # tensors back to the same HBM buffers across CTE and TKG graphs. + # + # Recurrent state: (B, num_v_heads, k_dim, v_dim) = (B, 32, 128, 128) + # The NKI kernel outputs per-head (128, 128) in float32; we store as bf16 + # on HBM and cast at load/store time. + # + # Conv state: last (kernel_size - 1) = 3 tokens of the mixed tensor + # (QKV concat before conv1d). Shape: (B, conv_dim, kernel_size - 1) + # = (B, 8192, 3). Stores the last 3 tokens' mixed values so TKG can + # compute conv1d correctly for the next token. + # + # Note: Use max_batch_size for buffer allocation so CTE and TKG models + # have identically-shaped state dict entries (required for loading). 
+ # In forward(), we slice to the actual batch_size. + # Both buffers are stored in the model's compute dtype (bf16). + alloc_batch_size = getattr(config.neuron_config, "max_batch_size", 1) + self._phase_batch_size = getattr(config.neuron_config, "batch_size", 1) + self.recurrent_state_buffer = nn.Parameter( + torch.zeros( + alloc_batch_size, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + dtype=config.neuron_config.torch_dtype, + ), + requires_grad=False, + ) + self.conv_state_buffer = nn.Parameter( + torch.zeros( + alloc_batch_size, + self.conv_dim, + self.conv_kernel_size - 1, # 3 + dtype=config.neuron_config.torch_dtype, + ), + requires_grad=False, + ) + + def _recurrent_step(self, query, key, value, g, beta, recurrent_state): + """Single-step recurrent update for token generation. + + Args: + query: (B, H, 1, k_dim) [H = num_v_heads after K-head expansion] + key: (B, H, 1, k_dim) + value: (B, H, 1, v_dim) + g: (B, H, 1) -- log-decay + beta: (B, H, 1) -- write gate + recurrent_state: (B, H, k_dim, v_dim) + + Returns: + output: (B, H, 1, v_dim) + new_state: (B, H, k_dim, v_dim) + """ + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + q_t = query[:, :, 0] # (B, H, k_dim) + k_t = key[:, :, 0] + v_t = value[:, :, 0] # (B, H, v_dim) + g_t = g[:, :, 0].exp().unsqueeze(-1).unsqueeze(-1) # (B, H, 1, 1) + beta_t = beta[:, :, 0].unsqueeze(-1) # (B, H, 1) + + # Decay old state + new_state = recurrent_state * g_t + # Compute delta update + kv_mem = (new_state * k_t.unsqueeze(-1)).sum(dim=-2) # (B, H, v_dim) + delta = (v_t - kv_mem) * beta_t + new_state = new_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + # Read out + output = (new_state * q_t.unsqueeze(-1)).sum(dim=-2) # (B, H, v_dim) + + return output.unsqueeze(2), new_state + + def _nki_recurrent_forward(self, query, key, value, g, beta): + """Full-sequence recurrent forward using NKI kernel for context encoding. 
+ + Uses the _state variant kernel to also return the final recurrent state + for CTE -> TKG carry-over. + + The NKI kernel processes a single (batch, head) pair's full sequence. + We call it in a loop over B*H from PyTorch. + + Args: + query: (B, H, S, k_dim) float32 + key: (B, H, S, k_dim) float32 + value: (B, H, S, v_dim) float32 + g: (B, H, S) float32 -- log-decay + beta: (B, H, S) float32 -- write gate + + Returns: + output: (B, H, S, v_dim) float32 + final_state: (B, H, k_dim, v_dim) float32 -- recurrent state after last token + """ + # L2-normalize and scale + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + # Flatten (B, H) -> BH for looping + BH = B * H + query_flat = query.reshape(BH, S, k_dim).contiguous() + key_flat = key.reshape(BH, S, k_dim).contiguous() + value_flat = value.reshape(BH, S, v_dim).contiguous() + + # Expand g/beta from (B, H, S) to (BH, S, 128) for NKI -- + # tensor_scalar requires operand0 shape (P_MAX, 1) matching partition axis. 
+ g_flat = g.reshape(BH, S).unsqueeze(-1).expand(-1, -1, v_dim).contiguous() + beta_flat = beta.reshape(BH, S).unsqueeze(-1).expand(-1, -1, v_dim).contiguous() + + # Call NKI kernel per (batch, head) pair -- 2D tensors (S, 128) + outputs = [] + states = [] + for bh in range(BH): + out_bh, state_bh = _deltanet_nki_kernel_state( + query_flat[bh], # (S, 128) + key_flat[bh], # (S, 128) + value_flat[bh], # (S, 128) + g_flat[bh], # (S, 128) + beta_flat[bh], # (S, 128) + ) + outputs.append(out_bh) + states.append(state_bh) # (128, 128) + + # Stack back to (BH, S, v_dim) then reshape to (B, H, S, v_dim) + output = torch.stack(outputs, dim=0) # (BH, S, v_dim) + output = output.reshape(B, H, S, v_dim) + + # Stack states to (BH, k_dim, v_dim) then reshape to (B, H, k_dim, v_dim) + final_state = torch.stack(states, dim=0) # (BH, k_dim, v_dim) + final_state = final_state.reshape(B, H, k_dim, v_dim) + + return output, final_state + + def _chunk_forward(self, query, key, value, g, beta, output_final_state=False): + """Chunk-based forward for context encoding (prefill). + + V5: Uses the chunked formulation from the reference (torch_chunk_gated_delta_rule). + Uses chunk_size=32, inter-chunk recurrent state propagation, and iterative + correction loop within each chunk. + + NOTE: The iterative correction loop uses variable-width slice assignments + which may cause accuracy issues under XLA tracing, but this approach + compiles and produces reasonable (not perfect) results. 
+ + Args: + query: (B, H, S, k_dim) -- already in float32 + key: (B, H, S, k_dim) -- already in float32 + value: (B, H, S, v_dim) -- already in float32 + g: (B, H, S) -- already in float32 + beta: (B, H, S) -- already in float32 + output_final_state: if True, return final recurrent state + + Returns: + output: (B, H, S, v_dim) + last_recurrent_state: (B, H, k_dim, v_dim) or None + """ + chunk_size = ( + 32 # V49: chunk_size=32 compiles; testing WITHOUT --auto-cast matmult + ) + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + # Pad to multiple of chunk_size + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + + # Reshape to chunks: (B, H, num_chunks, chunk_size, dim) + num_chunks = total_seq_len // chunk_size + query = query.reshape(B, H, num_chunks, chunk_size, k_dim) + key = key.reshape(B, H, num_chunks, chunk_size, k_dim) + value = value.reshape(B, H, num_chunks, chunk_size, v_dim) + k_beta = k_beta.reshape(B, H, num_chunks, chunk_size, k_dim) + v_beta = v_beta.reshape(B, H, num_chunks, chunk_size, v_dim) + g = g.reshape(B, H, num_chunks, chunk_size) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=0, + ) + + # Cumulative decay within each chunk + g = g.cumsum(dim=-1) + decay_mask = (g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().tril() + + # Intra-chunk delta rule correction (iterative) + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, 
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask=None,
    position_ids=None,
    past_key_value=None,
    **kwargs,
):
    """Run one DeltaNet (linear-attention) layer behind the NxDI decoder interface.

    The layer keeps no KV cache.  It returns zero-filled (K, V) tensors so the
    surrounding cache-update code still receives tensors of the shape it
    expects.  The real state lives in two buffers that are read and
    re-emitted here:

    - ``self.recurrent_state_buffer``: final delta-rule state per sequence slot.
    - ``self.conv_state_buffer``: the last ``kernel_size - 1`` pre-SiLU conv
      inputs per sequence slot (3 columns for the kernel_size=4 conv).

    Context encoding (``past_key_value is None``) starts from a zero state and
    writes both buffers; token generation reads them, advances one step and
    writes them back.

    Because the recurrence has no attention mask, right-padded prompt
    positions must be neutralised explicitly: the hidden states, the
    post-conv q/k/v, ``g`` and ``beta`` are all multiplied by the valid-token
    mask so padding neither decays nor updates the state.

    Args:
        hidden_states: (B, S, hidden_size) layer input.
        attention_mask: Last-resort source for the padding mask (2D, or 4D
            whose diagonal is taken).  NOTE(review): this path casts the mask
            as-is, so it assumes 1=valid / 0=padding -- confirm against the
            caller's mask convention.
        position_ids: Fallback source for the padding mask when no
            ``padding_mask`` kwarg is given.
        past_key_value: ``None`` selects context encoding, anything else
            selects token generation.
        **kwargs: ``seq_ids`` (slot index per batch row in the state buffers)
            and ``padding_mask`` ((B, S), 1=valid, 0=padding) are read.

    Returns:
        Tuple ``(output, dummy_kv, new_recurrent_state, new_conv_state)``
        where ``output`` is (B, S, hidden_size) after the output projection,
        ``dummy_kv`` is a pair of zero tensors, and the two state tensors
        cover the full allocated batch of their buffers.
    """
    batch_size, seq_len, _ = hidden_states.shape

    # Prefill vs. decode is decided purely by the presence of a cache entry.
    is_decode = past_key_value is not None

    # seq_ids maps each row of the current batch to its slot in the buffers.
    seq_ids = kwargs.get("seq_ids", None)

    # --- Padding mask (context encoding only) ---
    # Preferred source is the 2D padding_mask kwarg: it is a plain dtype cast
    # of a traced input, which the original authors found survives XLA
    # tracing, whereas the position_ids comparison was observed to be folded.
    padding_mask_input = kwargs.get("padding_mask", None)

    valid_mask_1d = None  # (B, S), 1.0 for valid, 0.0 for padding
    if not is_decode:
        if padding_mask_input is not None and padding_mask_input.dim() == 2:
            valid_mask_1d = padding_mask_input.to(hidden_states.dtype)
        elif position_ids is not None and position_ids.dim() == 2:
            # Fallback: derive from position_ids (may be folded by XLA).
            indices = torch.arange(
                seq_len, device=position_ids.device, dtype=position_ids.dtype
            ).unsqueeze(0)  # (1, S)
            valid_mask_1d = (position_ids >= indices).to(hidden_states.dtype)
        elif attention_mask is not None:
            if attention_mask.dim() == 4:
                pad_mask_1d = torch.diagonal(attention_mask[:, 0], dim1=-2, dim2=-1)
            elif attention_mask.dim() == 2:
                pad_mask_1d = attention_mask
            else:
                pad_mask_1d = None
            if pad_mask_1d is not None:
                valid_mask_1d = pad_mask_1d.to(hidden_states.dtype)

        if valid_mask_1d is not None:
            # Zero padded rows before projection so they contribute nothing.
            hidden_states = hidden_states * valid_mask_1d.unsqueeze(-1)

    # --- Input projections (optionally in float32 for debugging accuracy) ---
    deltanet_fp32 = os.environ.get("DELTANET_FP32") == "1"
    if deltanet_fp32:
        hs_f32 = hidden_states.float()
        qkv = F.linear(hs_f32, self.in_proj_qkv.weight.float()).to(
            hidden_states.dtype
        )
        z = F.linear(hs_f32, self.in_proj_z.weight.float()).to(hidden_states.dtype)
        b = F.linear(hs_f32, self.in_proj_b.weight.float()).to(hidden_states.dtype)
        a = F.linear(hs_f32, self.in_proj_a.weight.float()).to(hidden_states.dtype)
    else:
        qkv = self.in_proj_qkv(hidden_states)
        z = self.in_proj_z(hidden_states)
        b = self.in_proj_b(hidden_states)
        a = self.in_proj_a(hidden_states)

    # The causal conv runs over the concatenated [q | k | v] channels (not z),
    # which is exactly the qkv projection laid out channel-first.
    mixed = qkv.transpose(1, 2)  # (B, conv_dim, S)

    if is_decode:
        # Window = [3 cached columns, new token] -> one conv output column.
        if seq_ids is not None:
            conv_state = torch.index_select(self.conv_state_buffer, 0, seq_ids)
        else:
            # Buffer may be allocated for more rows than the current batch.
            conv_state = self.conv_state_buffer[:batch_size]
        conv_input = torch.cat([conv_state, mixed], dim=-1)  # (B, conv_dim, 4)

        # Depthwise conv applied by hand: out = sum_k w[:, k] * window[:, :, k]
        w = self.conv1d.weight.squeeze(1)  # (conv_dim, 4)
        conv_out = torch.zeros_like(mixed)  # (B, conv_dim, 1)
        for k in range(4):
            conv_out = (
                conv_out
                + w[:, k].unsqueeze(0).unsqueeze(-1) * conv_input[:, :, k : k + 1]
            )
        mixed_post_conv = F.silu(conv_out)

        # Slide the window: drop the oldest column, append the new token.
        new_conv_state = torch.cat([conv_state[:, :, 1:], mixed], dim=-1)

        # Write back into the full-size buffer layout.
        alloc_bs = self.conv_state_buffer.shape[0]
        if seq_ids is not None:
            idx = seq_ids.view(-1, 1, 1).expand_as(new_conv_state)
            new_conv_state = (self.conv_state_buffer * 1).scatter(
                0, idx, new_conv_state
            )
        elif batch_size < alloc_bs:
            new_conv_state = torch.cat(
                [
                    new_conv_state,
                    # touch remaining buffer entries
                    self.conv_state_buffer[batch_size:] * 0,
                ],
                dim=0,
            )
        else:
            # "* 0" keeps the aliased buffer in the traced graph.
            new_conv_state = new_conv_state + self.conv_state_buffer * 0
    else:
        # nn.Conv1d pads both sides by kernel_size - 1; keeping the first
        # seq_len outputs yields the causal convolution.
        skip_conv = os.environ.get("SKIP_CONV1D") == "1"
        if skip_conv:
            # DEBUG: bypass the conv and only apply SiLU.
            mixed_post_conv = F.silu(mixed[:, :, :seq_len])
        else:
            mixed_post_conv = F.silu(self.conv1d(mixed)[:, :, :seq_len])

        # Capture the conv inputs of the last 3 VALID tokens.  With right
        # padding these sit at positions n-3..n-1, not at the end of the bucket.
        if valid_mask_1d is not None:
            # Count in float32: a bf16 sum cannot represent every integer
            # above 256, which would shift the gathered window by one.
            num_valid = (
                valid_mask_1d.float().sum(dim=-1, keepdim=True).long()
            )  # (B, 1)
            offsets = torch.arange(3, device=mixed.device).unsqueeze(0)  # (1, 3)
            raw_idx = num_valid - 3 + offsets  # (B, 3), negative if n < 3
            # Positions before the start of the prompt are the conv's causal
            # zero padding: gather a dummy column there and zero it, so a
            # 2-token prompt yields [0, x0, x1] rather than [x0, x1, pad].
            in_range = (raw_idx >= 0).to(mixed.dtype).unsqueeze(1)  # (B, 1, 3)
            gather_idx = (
                raw_idx.clamp(min=0).unsqueeze(1).expand(-1, self.conv_dim, -1)
            )
            new_conv_state = torch.gather(mixed, 2, gather_idx) * in_range
        else:
            # No mask available: assume the whole bucket is valid.
            new_conv_state = mixed[:, :, -3:].contiguous()

        # The old buffer is not needed numerically during prefill, but it is
        # aliased as an output and therefore must appear in the traced graph.
        alloc_bs = self.conv_state_buffer.shape[0]
        if seq_ids is not None:
            idx = seq_ids.view(-1, 1, 1).expand_as(new_conv_state)
            new_conv_state = (self.conv_state_buffer * 1).scatter(
                0, idx, new_conv_state
            )
        elif batch_size < alloc_bs:
            pad_size = alloc_bs - batch_size
            new_conv_state = torch.cat(
                [
                    new_conv_state,
                    torch.zeros(
                        pad_size,
                        self.conv_dim,
                        self.conv_kernel_size - 1,
                        dtype=new_conv_state.dtype,
                        device=new_conv_state.device,
                    ),
                ],
                dim=0,
            )
            new_conv_state = new_conv_state + self.conv_state_buffer * 0
        else:
            new_conv_state = new_conv_state + self.conv_state_buffer * 0

    # Back to channel-last and split into q / k / v.
    mixed_post_conv = mixed_post_conv.transpose(1, 2)  # (B, S, conv_dim)
    query = mixed_post_conv[..., : self.key_dim]
    key = mixed_post_conv[..., self.key_dim : self.key_dim * 2]
    value = mixed_post_conv[..., self.key_dim * 2 :]

    # Reshape to heads.
    query = query.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim)
    key = key.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim)
    value = value.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim)

    # The conv mixes neighbouring positions, so padded positions directly
    # after the prompt receive leakage from valid tokens; zero them again.
    if valid_mask_1d is not None and not is_decode:
        post_conv_mask = valid_mask_1d.unsqueeze(-1).unsqueeze(-1)  # (B, S, 1, 1)
        query = query * post_conv_mask
        key = key * post_conv_mask
        value = value * post_conv_mask

    # Gating terms.
    beta = b.sigmoid()  # (B, S, num_v_heads)
    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)

    if valid_mask_1d is not None:
        # g = 0 -> decay exp(g) = 1, beta = 0 -> no write: padding positions
        # pass the recurrent state through unchanged.
        g = g * valid_mask_1d.float().unsqueeze(-1)
        beta = beta * valid_mask_1d.unsqueeze(-1)

    # Repeat each key head so q/k have as many heads as v.
    if self.num_v_heads // self.num_k_heads > 1:
        rep = self.num_v_heads // self.num_k_heads
        query = (
            query.unsqueeze(3)
            .expand(-1, -1, -1, rep, -1)
            .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim)
        )
        key = (
            key.unsqueeze(3)
            .expand(-1, -1, -1, rep, -1)
            .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim)
        )

    # (B, H, S, dim) in float32 for the delta rule.
    query = query.transpose(1, 2).contiguous().float()
    key = key.transpose(1, 2).contiguous().float()
    value = value.transpose(1, 2).contiguous().float()
    g = g.transpose(1, 2).contiguous().float()
    beta = beta.transpose(1, 2).contiguous().float()

    if is_decode:
        # Load this batch's recurrent state (storage dtype -> float32).
        if seq_ids is not None:
            recurrent_state = torch.index_select(
                self.recurrent_state_buffer, 0, seq_ids
            ).float()
        else:
            recurrent_state = self.recurrent_state_buffer[:batch_size].float()
        output, new_recurrent_state = self._recurrent_step(
            query, key, value, g, beta, recurrent_state
        )
        # Write back into the full-size buffer layout.
        alloc_bs = self.recurrent_state_buffer.shape[0]
        if seq_ids is not None:
            idx = seq_ids.view(-1, 1, 1, 1).expand_as(new_recurrent_state)
            new_recurrent_state = (self.recurrent_state_buffer.float() * 1).scatter(
                0, idx, new_recurrent_state
            )
        elif batch_size < alloc_bs:
            new_recurrent_state = torch.cat(
                [
                    new_recurrent_state,
                    self.recurrent_state_buffer[batch_size:].float() * 0,
                ],
                dim=0,
            )
        else:
            new_recurrent_state = (
                new_recurrent_state + self.recurrent_state_buffer.float() * 0
            )
    else:
        # Context encoding uses the NKI recurrent kernel rather than the
        # chunked path.
        # TODO: Switch back to _chunk_forward once accuracy is validated or
        # the XLA variable-width-slice loop miscompilation is fixed.
        output, new_recurrent_state = self._nki_recurrent_forward(
            query, key, value, g, beta
        )
        # As with the conv state, the old buffer must be touched so the
        # aliased parameter is part of the traced graph.
        alloc_bs = self.recurrent_state_buffer.shape[0]
        if seq_ids is not None:
            idx = seq_ids.view(-1, 1, 1, 1).expand_as(new_recurrent_state)
            new_recurrent_state = (self.recurrent_state_buffer.float() * 1).scatter(
                0, idx, new_recurrent_state
            )
        elif batch_size < alloc_bs:
            pad_size = alloc_bs - batch_size
            new_recurrent_state = torch.cat(
                [
                    new_recurrent_state,
                    torch.zeros(
                        pad_size,
                        self.num_v_heads,
                        self.head_k_dim,
                        self.head_v_dim,
                        dtype=new_recurrent_state.dtype,
                        device=new_recurrent_state.device,
                    ),
                ],
                dim=0,
            )
            new_recurrent_state = (
                new_recurrent_state + self.recurrent_state_buffer.float() * 0
            )
        else:
            new_recurrent_state = (
                new_recurrent_state + self.recurrent_state_buffer.float() * 0
            )

    # Cast the recurrent state back to the storage dtype.
    new_recurrent_state = new_recurrent_state.to(hidden_states.dtype)

    # Back to (B, S, H, v_dim) then (B, S, value_dim).
    output = output.transpose(1, 2).contiguous().to(hidden_states.dtype)
    output = output.reshape(batch_size, seq_len, -1)

    # Gated norm per head: norm(output) * silu(z), then the output projection.
    z_flat = z.reshape(-1, self.head_v_dim)
    output_flat = output.reshape(-1, self.head_v_dim)
    output_flat = self.norm(output_flat) * F.silu(z_flat)
    output = output_flat.reshape(batch_size, seq_len, self.value_dim)
    if deltanet_fp32:
        output = F.linear(output.float(), self.out_proj.weight.float()).to(
            hidden_states.dtype
        )
    else:
        output = self.out_proj(output)

    # Zero-filled K/V placeholders so cache bookkeeping keeps working.
    dummy_shape = (batch_size, self.kv_heads_per_rank, seq_len, self.head_dim)
    dummy_k = torch.zeros(
        dummy_shape, dtype=hidden_states.dtype, device=hidden_states.device
    )
    dummy_v = torch.zeros(
        dummy_shape, dtype=hidden_states.dtype, device=hidden_states.device
    )
    dummy_kv = (dummy_k, dummy_v)

    return output, dummy_kv, new_recurrent_state, new_conv_state
class Qwen35MoeInferenceConfig(InferenceConfig):
    """Inference config for the hybrid DeltaNet + full-attention MoE decoder.

    On top of the base config this class:

    - synthesises ``layer_types`` (three linear-attention layers followed by
      one full-attention layer, repeated) when the caller does not supply it;
    - fills in defaults for partial RoPE, mRoPE and the DeltaNet dimensions;
    - disables the framework's built-in shared experts (``n_shared_experts=0``)
      because the shared expert is applied manually with a sigmoid gate;
    - configures the MoE router (float32, softmax, normalised top-k).
    """

    @staticmethod
    def _default_layer_types(num_layers):
        """Return the repeating 3x linear / 1x full attention layer pattern."""
        pattern = (
            "linear_attention",
            "linear_attention",
            "linear_attention",
            "full_attention",
        )
        # A trailing partial group only ever contains linear-attention layers,
        # which is exactly what indexing the pattern modulo 4 produces.
        return [pattern[i % 4] for i in range(num_layers)]

    @staticmethod
    def _rope_param(rope_params, name, default):
        """Read ``name`` from ``rope_parameters`` given as a dict or an object."""
        if not rope_params:
            return default
        if isinstance(rope_params, dict):
            return rope_params.get(name, default)
        return getattr(rope_params, name, default)

    def __init__(self, *args, **kwargs):
        # layer_types must exist before super().__init__(), which validates
        # the required attributes.
        if "layer_types" not in kwargs:
            kwargs["layer_types"] = self._default_layer_types(
                kwargs.get("num_hidden_layers", 48)
            )

        super().__init__(*args, **kwargs)

        # Model-specific attributes from text_config
        self.num_local_experts = self.num_experts
        # Shared expert: save intermediate_size for manual shared expert MLP
        self.shared_expert_intermediate_size = getattr(
            self, "shared_expert_intermediate_size", 512
        )
        # The shared expert is handled manually with sigmoid gating in the
        # decoder layer, so the framework's ungated shared-expert path is
        # switched off.
        self.n_shared_experts = 0
        self.intermediate_size = self.moe_intermediate_size

        # Attention output gate
        self.attn_output_gate = getattr(self, "attn_output_gate", True)

        # Partial RoPE: only a fraction of head_dim is rotated.
        self.partial_rotary_factor = getattr(self, "partial_rotary_factor", 0.25)
        self.rope_dim = int(self.head_dim * self.partial_rotary_factor)

        # mRoPE settings live inside rope_parameters in the HF config format.
        rope_params = getattr(self, "rope_parameters", None)
        self.mrope_section = self._rope_param(
            rope_params, "mrope_section", [11, 11, 10]
        )
        self.mrope_interleaved = self._rope_param(
            rope_params, "mrope_interleaved", True
        )

        # Layer types for hybrid dispatch (only reached if the base class did
        # not store the kwarg); sized from the actual layer count.
        if not hasattr(self, "layer_types"):
            self.layer_types = self._default_layer_types(self.num_hidden_layers)

        # Standard HF attributes expected by the base class, followed by the
        # DeltaNet dimensions; each is only set when missing.
        fallback_values = (
            ("output_attentions", False),
            ("output_hidden_states", False),
            ("linear_num_value_heads", 32),
            ("linear_num_key_heads", 16),
            ("linear_key_head_dim", 128),
            ("linear_value_head_dim", 128),
            ("linear_conv_kernel_dim", 4),
        )
        for attr_name, fallback in fallback_values:
            if not hasattr(self, attr_name):
                setattr(self, attr_name, fallback)

        # MoE config
        self.maybe_pad_intermediate()
        self.enable_moe_fused_nki_kernel()
        self.neuron_config.router_config.dtype = torch.float32
        self.neuron_config.router_config.act_fn = "softmax"
        self.neuron_config.disable_numeric_cc_token = True
        self.neuron_config.normalize_top_k_affinities = True

    def add_derived_config(self):
        """Promote text_config fields before validation.

        When loaded via vLLM, the HF config has a multimodal-style layout
        where model fields live inside text_config rather than at the top
        level. This method promotes them so that validate_config() and the
        rest of the model can access them as direct attributes.
        """
        if hasattr(self, "text_config") and not hasattr(self, "hidden_size"):
            tc = self.text_config
            for attr in dir(tc):
                if not attr.startswith("_") and not hasattr(self, attr):
                    setattr(self, attr, getattr(tc, attr))

        # rope_theta lives inside rope_parameters in the HF config
        if not hasattr(self, "rope_theta"):
            rope_params = getattr(self, "rope_parameters", None)
            if rope_params is not None:
                self.rope_theta = self._rope_param(
                    rope_params, "rope_theta", 10000000
                )

        super().add_derived_config()

    def maybe_pad_intermediate(self):
        """Pad ``moe_intermediate_size`` for shard-on-intermediate blockwise matmul.

        Only acts when ``use_shard_on_intermediate_dynamic_while`` is set and
        the per-rank intermediate size is not already a multiple of
        ``SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP``; the padding amount is
        recorded in ``moe_intermediate_pad_size``.
        """
        moe_tp_degree = self.neuron_config.moe_tp_degree
        I_TP = self.moe_intermediate_size // moe_tp_degree
        if getattr(
            self.neuron_config.blockwise_matmul_config,
            "use_shard_on_intermediate_dynamic_while",
            False,
        ):
            if I_TP % SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP != 0:
                padded = (
                    math.ceil(I_TP / SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP)
                    * SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP
                    * moe_tp_degree
                )
                self.moe_intermediate_pad_size = max(
                    padded - self.moe_intermediate_size, 0
                )
                self.moe_intermediate_size = padded

    def enable_moe_fused_nki_kernel(self):
        """Enable the fused MoE token-generation kernel when it is usable.

        Requires the neuron config to request it and the per-rank
        intermediate size to be a multiple of ``MOE_TKG_MK_INTERMEDIATE_PER_TP``.
        """
        I_TP = self.moe_intermediate_size // self.neuron_config.moe_tp_degree
        if (
            getattr(self.neuron_config, "moe_fused_nki_kernel_enabled", False)
            and I_TP % MOE_TKG_MK_INTERMEDIATE_PER_TP == 0
        ):
            self.moe_fused_nki_kernel_enabled = True
            # Patch the fused TKG kernel to use the ISA router fallback.
            # The NKI router_topk_kernel asserts that input and weight dtypes
            # match, but our router uses float32 for accuracy while weights
            # are bfloat16. The ISA fallback handles mixed dtypes correctly.
            _patch_fused_tkg_for_fp32_router()

    def get_required_attributes(self) -> List[str]:
        """Return the attribute names that must be present after loading."""
        return [
            "head_dim",
            "hidden_act",
            "hidden_size",
            "max_position_embeddings",
            "moe_intermediate_size",
            "num_attention_heads",
            "num_experts",
            "num_experts_per_tok",
            "num_hidden_layers",
            "num_key_value_heads",
            "rms_norm_eps",
            "rope_theta",
            "vocab_size",
            # DeltaNet-specific
            "linear_num_value_heads",
            "linear_num_key_heads",
            "linear_key_head_dim",
            "linear_value_head_dim",
            "linear_conv_kernel_dim",
            "layer_types",
        ]

    @classmethod
    def get_neuron_config_cls(cls):
        """Return the neuron config class used for this model (MoE variant)."""
        return MoENeuronConfig
class Qwen35MRoPEEmbedding(nn.Module):
    """Multimodal rotary position embedding (mRoPE) restricted to ``rope_dim``.

    Position ids may carry three axes (temporal, height, width) with shape
    (3, batch_size, seq_len).  Plain 2D ids (batch_size, seq_len) are
    broadcast so that all three axes share the same positions.

    In interleaved mode the per-axis frequencies are woven together with a
    stride of three (T H W T H W ...); otherwise only the temporal axis is
    used.  The returned cos/sin tables have ``rope_dim`` columns.
    """

    def __init__(self, config: Qwen35MoeInferenceConfig):
        super().__init__()
        self.dim = config.rope_dim
        self.max_position_embeddings = config.max_position_embeddings
        self.base = config.rope_theta

        # mRoPE layout settings, with defaults when the config lacks them.
        self.mrope_section = getattr(config, "mrope_section", [11, 11, 10])
        self.mrope_interleaved = getattr(config, "mrope_interleaved", True)

        # One inverse frequency per rotated pair: rope_dim / 2 entries.
        exponents = torch.arange(0, self.dim, 2, dtype=torch.float32) / self.dim
        self.register_buffer(
            "inv_freq", 1.0 / (self.base**exponents), persistent=False
        )

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Weave height/width frequencies into the temporal ones.

        Args:
            freqs: (3, bs, seq_len, rope_dim // 2) frequencies for T, H, W.
            mrope_section: per-axis section sizes, indexed T, H, W.

        Returns:
            (bs, seq_len, rope_dim // 2) tensor: a copy of the temporal
            frequencies in which every third column starting at 1 comes from
            H and every third column starting at 2 comes from W, each up to
            three times that axis' section size.
        """
        woven = freqs[0].clone()
        for axis in (1, 2):  # H then W; the axis index doubles as the offset
            columns = slice(axis, mrope_section[axis] * 3, 3)
            woven[..., columns] = freqs[axis, ..., columns]
        return woven

    def forward(self, x, position_ids):
        """Compute the mRoPE cos/sin tables.

        Args:
            x: Tensor consulted only for its device type and dtype.
            position_ids: (3, batch_size, seq_len) or (batch_size, seq_len).

        Returns:
            ``(cos, sin)``, each (batch_size, seq_len, rope_dim) in ``x.dtype``.
        """
        if position_ids.ndim == 2:
            # Text-only ids: replicate across the three axes.
            n_batch = position_ids.shape[0]
            position_ids = position_ids[None, ...].expand(3, n_batch, -1)

        # (rope_dim/2,) -> (3, bs, rope_dim/2, 1)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float()
        inv_freq_expanded = inv_freq_expanded.expand(
            3, position_ids.shape[1], -1, 1
        )
        # (3, bs, seq_len) -> (3, bs, 1, seq_len)
        position_ids_expanded = position_ids[:, :, None, :].float()

        # Autocast is disabled so the outer product stays in float32.
        device_type = x.device.type
        if not (isinstance(device_type, str) and device_type != "mps"):
            device_type = "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            outer = inv_freq_expanded.float() @ position_ids_expanded.float()
            freqs = outer.transpose(2, 3)  # (3, bs, seq_len, rope_dim/2)

        if self.mrope_interleaved:
            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
        else:
            freqs = freqs[0]

        # Duplicate to cover both halves of the rotated dimensions.
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
class NeuronQwen35Attention(NeuronAttentionBase):
    """Full-attention layer with partial RoPE, per-head QK norm and an output gate.

    The sigmoid output gate is applied to the attention output BEFORE o_proj.
    In the HF checkpoint the gate weights are the second half of
    ``q_proj.weight``; here they live in a separate ``output_gate_proj``.

    For head_dim > 128 prefill uses the custom d=256 flash-attention kernel
    whenever the sequence length is a positive multiple of 512; that kernel
    applies partial RoPE itself, so RoPE is skipped beforehand in that case.
    """

    def __init__(self, config: Qwen35MoeInferenceConfig):
        # Number of leading head dimensions that receive rotary embedding.
        self.rope_dim = config.rope_dim

        # QK norm modules are created first so they can be handed to the base.
        rms_norm_eps = config.rms_norm_eps
        q_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps)
        k_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps)

        # Standard rotary embedding over rope_dim only.  For 3D mRoPE
        # positions cos/sin are expected to arrive pre-computed through
        # cos_cache/sin_cache instead.
        rotary_emb = RotaryEmbedding(
            self.rope_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )
        super().__init__(
            config=config,
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            rotary_emb=rotary_emb,
            rms_norm_eps=rms_norm_eps,
            use_qk_norm=False,
            q_layernorm=q_ln,
            k_layernorm=k_ln,
        )

        # Separate mRoPE module for 3D position ids (not used as self.rotary_emb).
        self.mrope_emb = Qwen35MRoPEEmbedding(config)

        # Output gate projection: hidden_size -> num_heads * head_dim.
        # Column-parallel without gathering, so each rank keeps the shard that
        # matches its slice of the attention output.
        self.output_gate_proj = ColumnParallelLinear(
            config.hidden_size,
            config.num_attention_heads * config.head_dim,
            gather_output=False,
            bias=False,
        )

    def _d256_kernel_applies(self, seq_len):
        """Return True when prefill at ``seq_len`` runs the d=256 flash kernel.

        Single source of truth for both ``apply_rotary_embedding`` (which must
        skip RoPE exactly when the kernel will fuse it) and ``perform_prefill``.
        """
        return self.head_dim > 128 and seq_len >= 512 and seq_len % 512 == 0

    def apply_rotary_embedding(
        self, Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope
    ):
        """Apply rotary embedding to the first ``rope_dim`` dims of Q and K.

        cos/sin are computed from ``self.rotary_emb`` unless both caches are
        supplied.  During prefill, when the d=256 kernel will be used, Q and K
        are returned un-rotated because the kernel applies partial RoPE
        internally; in every other case (short/odd prefill lengths that fall
        back to the softmax path, and decode) RoPE is applied here.

        Returns:
            ``(Q, K, cos_cache, sin_cache)``.
        """
        from neuronx_distributed_inference.modules.attention.utils import (
            apply_rotary_pos_emb,
        )

        if self.rotary_emb is not None:
            if cos_cache is None or sin_cache is None:
                cos_cache, sin_cache = self.rotary_emb(V, position_ids)

            seq_len = Q.shape[2]
            if self.neuron_config.is_prefill_stage and self._d256_kernel_applies(
                seq_len
            ):
                # Kernel path: hand back pre-RoPE Q/K plus the cos/sin tables.
                return Q, K, cos_cache, sin_cache

            # Split into rotated and pass-through portions.
            q_rope = Q[..., : self.rope_dim]
            q_pass = Q[..., self.rope_dim :]
            k_rope = K[..., : self.rope_dim]
            k_pass = K[..., self.rope_dim :]

            q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos_cache, sin_cache)

            Q = torch.cat([q_rope, q_pass], dim=-1)
            K = torch.cat([k_rope, k_pass], dim=-1)

        return Q, K, cos_cache, sin_cache

    def perform_prefill(
        self, Q, K, V, q_len, bsz, attention_mask, cos_cache=None, sin_cache=None
    ):
        """Prefill attention with support for head_dim > 128.

        When the d=256 kernel applies (see ``_d256_kernel_applies``), Q/K/V in
        (B, H, S, D) layout are passed to it one (batch, kv_head) pair at a
        time, with cos/sin tables for fused partial RoPE when both are given.
        Otherwise the base implementation is used; for head_dim > 128 the
        base class' attention kernel is temporarily disabled so the softmax
        path is taken.

        Returns:
            ``(attn_output, strategy)`` with ``attn_output`` shaped
            (B, H, S, D).
        """
        if self._d256_kernel_applies(q_len):
            q_kernel = Q.to(self.torch_dtype)
            k_kernel = K.to(self.torch_dtype)
            v_kernel = V.contiguous().to(self.torch_dtype)

            fuse_rope = cos_cache is not None and sin_cache is not None
            if fuse_rope:
                # NOTE(review): only batch row 0 of the tables is passed, so
                # this assumes every row shares the same positions -- confirm
                # for bsz > 1.
                cos_kernel = cos_cache[0].to(self.torch_dtype)  # (S, rope_dim)
                sin_kernel = sin_cache[0].to(self.torch_dtype)  # (S, rope_dim)
            else:
                cos_kernel = None
                sin_kernel = None

            n_kv_heads = K.shape[1]
            n_q_heads = Q.shape[1]
            q_h_per_kv = n_q_heads // n_kv_heads

            # The kernel processes one KV head (and its query group) per call.
            out_parts = []
            for b in range(bsz):
                for kv_h in range(n_kv_heads):
                    q_slice = q_kernel[
                        b : b + 1, kv_h * q_h_per_kv : (kv_h + 1) * q_h_per_kv, :, :
                    ]
                    k_slice = k_kernel[b : b + 1, kv_h : kv_h + 1, :, :]
                    v_slice = v_kernel[b : b + 1, kv_h : kv_h + 1, :, :]
                    o_part = _flash_attn_d256_kernel(
                        q_slice,
                        k_slice,
                        v_slice,
                        cos_cache=cos_kernel,
                        sin_cache=sin_kernel,
                        use_causal_mask=True,
                        q_h_per_k_h=q_h_per_kv,
                        n_kv_heads=1,
                        seqlen_q=q_len,
                        seqlen_kv=q_len,
                        rope_dim=self.rope_dim if fuse_rope else 0,
                    )
                    out_parts.append(o_part)

            # Parts are ordered batch-major, so concatenating along the head
            # axis and reshaping recovers (B, H, S, D).
            attn_output = torch.cat(out_parts, dim=1)
            if bsz > 1:
                attn_output = attn_output.reshape(bsz, n_q_heads, q_len, self.head_dim)

            return attn_output, FlashAttentionStrategy.NONE

        if self.head_dim > 128:
            # Softmax fallback: disable the base kernel only for this call and
            # restore the flag even if the call raises.
            saved = self.attn_kernel_enabled
            self.attn_kernel_enabled = False
            try:
                return super().perform_prefill(Q, K, V, q_len, bsz, attention_mask)
            finally:
                self.attn_kernel_enabled = saved
        return super().perform_prefill(Q, K, V, q_len, bsz, attention_mask)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        active_mask=None,
        adapter_ids=None,
        cos_cache=None,
        sin_cache=None,
        rmsnorm=None,
        rotary_position_ids=None,
        **kwargs,
    ):
        """Attention forward pass with the sigmoid gate applied before o_proj.

        Sequence of operations:
            gate        = output_gate_proj(hidden_states)
            attn_output = attention(Q, K, V)
            attn_output = attn_output * sigmoid(gate)
            attn_output = o_proj(attn_output)

        ``past_key_value is None`` selects prefill; otherwise token
        generation.  Pre-computed ``cos_cache``/``sin_cache`` are forwarded to
        QKV preparation and, in prefill, to the attention kernel.
        ``rotary_position_ids`` and extra kwargs are accepted but not read.

        Returns:
            ``(attn_output, (K, V), cos_cache, sin_cache)``.
        """
        bsz, q_len, _ = hidden_states.shape

        # Gate is computed from the layer input, before the QKV projections.
        gate = self.output_gate_proj(hidden_states)

        # Standard QKV prep; the 2D position_ids are used here.
        Q, K, V, cos_cache, sin_cache, _residual = self.prep_qkv_tensors(
            position_ids,
            hidden_states,
            past_key_value,
            adapter_ids=adapter_ids,
            cos_cache=cos_cache,
            sin_cache=sin_cache,
            rmsnorm=rmsnorm,
        )

        if past_key_value is None:
            # Context encoding: cos/sin go along for fused RoPE in the kernel.
            attn_output, _flash_strategy = self.perform_prefill(
                Q,
                K,
                V,
                q_len,
                bsz,
                attention_mask,
                cos_cache=cos_cache,
                sin_cache=sin_cache,
            )
        else:
            # Token generation.  A 2D (B, S) mask would broadcast as
            # (1, 1, B, S) and leak the batch axis into q_len for B > 1, so it
            # is lifted to (B, 1, 1, S) first.
            tkg_mask = attention_mask
            if tkg_mask is not None and tkg_mask.ndim == 2:
                tkg_mask = tkg_mask.unsqueeze(1).unsqueeze(2)
            attn_output = self.compute_for_token_gen(
                Q, K, V, position_ids, past_key_value, tkg_mask, active_mask
            )

        # (B, H, S, head_dim) -> (B, S, H * head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)

        # Sigmoid output gate, then the output projection.
        attn_output = attn_output * torch.sigmoid(gate)
        attn_output = self.get_o_proj()(attn_output, adapter_ids=adapter_ids)

        past_key_value = (K, V)
        return attn_output, past_key_value, cos_cache, sin_cache
class SigmoidGatedSharedExperts(nn.Module):
    """NxDI ``SharedExperts`` plus the Qwen3.5 per-token sigmoid gate.

    Result: ``sigmoid(x @ gate_weight.T) * SharedExpertMLP(x)``.

    The inner ``SharedExperts`` is built from ColumnParallelLinear (gate/up)
    and RowParallelLinear(reduce_output=False) (down), so the MLP result is
    TP-partial. Inside ``MoE._apply_shared_experts`` (CTE path) the MoE module
    performs the all-reduce after adding it to the routed output; when this
    module is called on its own (TKG path) the caller has to reduce.

    Weight layout:
        shared_experts.gate_proj.weight: (intermediate_size, hidden_size) = (512, 2048)
        shared_experts.up_proj.weight:   (intermediate_size, hidden_size) = (512, 2048)
        shared_experts.down_proj.weight: (hidden_size, intermediate_size) = (2048, 512)
        sigmoid_gate.weight:             (1, hidden_size) = (1, 2048)
    """

    def __init__(self, config):
        super().__init__()
        from neuronx_distributed.modules.moe.shared_experts import SharedExperts

        neuron_cfg = config.neuron_config

        self.shared_experts = SharedExperts(
            hidden_size=config.hidden_size,
            intermediate_size=config.shared_expert_intermediate_size,
            num_shared_experts=1,
            hidden_act=config.hidden_act,
            dtype=neuron_cfg.torch_dtype,
            reduce_dtype=neuron_cfg.rpl_reduce_dtype,
            fused_gate_up_projection=False,
            sequence_parallel_enabled=False,
        )

        # Scalar gate per token: hidden_size -> 1, computed from full hidden states.
        self.sigmoid_gate = nn.Linear(config.hidden_size, 1, bias=False)

        # Kept as a plain Python constant so it is safe to branch on while tracing.
        self.ep_degree = neuron_cfg.ep_degree

    @property
    def sequence_parallel_enabled(self):
        """Mirror the inner module's flag; MoE._apply_shared_experts reads it."""
        return self.shared_experts.sequence_parallel_enabled

    def preshard_hook(self, model_state_dict, prefix):
        """Forward the preshard call to the wrapped SharedExperts."""
        self.shared_experts.preshard_hook(model_state_dict, prefix)

    def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
        """Return the sigmoid-gated shared-expert output (TP-partial).

        Args:
            x: (T, H) flattened hidden states (full, not TP-partial).
            seq_len: sequence length, passed through to the inner module.

        Returns:
            (T, H) gated output, still TP-partial from down_proj.

        EP scaling: with Expert Parallelism the result is summed with the
        routed-expert output and all-reduced over the world_group (TP*EP).
        Routed output differs per EP rank, but this shared output is identical
        on every EP rank, so the reduction would count it ep_degree times.
        Dividing by ep_degree here cancels that. It only has an effect when
        ep_degree > 1.
        """
        # Inner MLP first, then the (T, 1) gate broadcast over the hidden dim.
        gated = self.shared_experts(x, seq_len) * torch.sigmoid(self.sigmoid_gate(x))

        if self.ep_degree > 1:
            return gated / self.ep_degree
        return gated
+ + Interface contract with NxDI get_model_output: + - forward() receives: hidden_states, seq_ids, attention_mask, position_ids, + past_key_value, active_mask, adapter_ids, cos_cache, sin_cache, + rotary_position_ids, kv_mgr, get_kv_per_layer, update_kv_per_layer, + idx, is_for_context_encoding, seq_len, residual, local_mask, + windowed_context_encoding_window_idx, padding_mask, **kwargs + - forward() returns: (hidden_states, present_key_value, cos_cache, sin_cache, None) + """ + + def __init__(self, config: Qwen35MoeInferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_type = config.layer_types[layer_idx] + self.layer_idx = layer_idx + self.config = config + + # Attention (DeltaNet or standard GQA) + if self.layer_type == "linear_attention": + self.linear_attn = NeuronGatedDeltaNet(config, layer_idx) + else: + self.self_attn = NeuronQwen35Attention(config=config) + + # MoE (all layers) -- uses NxDI's initialize_moe_module + # n_shared_experts=0 so NxDI MoE creates no shared experts internally + self.moe_fused_nki_kernel_enabled = getattr( + config, "moe_fused_nki_kernel_enabled", False + ) + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + + if self.moe_fused_nki_kernel_enabled: + # Fused TKG mega-kernel: pass rmsnorm=None to avoid aliasing. + # CTE path: decoder applies post_attention_layernorm before MoE. + # TKG path: fused kernel applies norm internally via its own RMSNorm. 
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask=None,
    position_ids=None,
    past_key_value=None,
    padding_mask=None,
    cos_cache=None,
    sin_cache=None,
    **kwargs,
):
    """Run one hybrid decoder layer: attention (DeltaNet or GQA), then MoE.

    Args:
        hidden_states: layer input; unpacked as (batch, seq, hidden).
        attention_mask, position_ids, past_key_value: forwarded to the
            attention sub-module.
        padding_mask: forwarded to the DeltaNet sub-module and to the MoE.
        cos_cache, sin_cache: rotary tables forwarded to the GQA sub-module.
            DeltaNet layers return them unchanged.
        **kwargs: forwarded to the attention sub-module.

    Returns:
        A 6-tuple ``(hidden_states, present_key_value, cos_cache, sin_cache,
        None, deltanet_states)``. ``deltanet_states`` is
        ``(new_recurrent_state, new_conv_state)`` for DeltaNet layers and
        ``None`` otherwise.

    Diagnostic environment switches (each active when set to ``"1"``):
        SKIP_LAYER_COMPUTE: return the input unchanged with a zero KV pair.
        DELTANET_PROJ_ONLY: DeltaNet layers run only input_layernorm and one
            projection, then add a tiny perturbation to the input.
        SKIP_ATTN_COMPUTE: bypass the attention sub-module, keep MoE.
        SKIP_MOE_COMPUTE: bypass MoE, keep attention.
    """
    # V30 identity diagnostic: skip all layer computation.
    if os.environ.get("SKIP_LAYER_COMPUTE") == "1":
        dummy_kv = self._make_dummy_kv(hidden_states)
        hidden_states = ModuleMarkerStartWrapper()(hidden_states)
        hidden_states = ModuleMarkerEndWrapper()(hidden_states)
        return (hidden_states, dummy_kv, None, None, None, None)

    # V34: input_layernorm + a single projection, to isolate where
    # divergence starts.
    if (
        os.environ.get("DELTANET_PROJ_ONLY") == "1"
        and self.layer_type == "linear_attention"
    ):
        dummy_kv = self._make_dummy_kv(hidden_states)
        hidden_states = ModuleMarkerStartWrapper()(hidden_states)
        normed = self.input_layernorm(hidden_states)
        proj_out = self.linear_attn.in_proj_qkv(normed)
        # Use the mean over the first hidden_size projection dims as a
        # per-token scalar so the output still depends on the projection
        # without running the full DeltaNet computation.
        scale = proj_out[..., : self.hidden_size].mean(dim=-1, keepdim=True)
        hidden_states = hidden_states + scale * 0.001
        hidden_states = ModuleMarkerEndWrapper()(hidden_states)
        return (hidden_states, dummy_kv, None, None, None, None)

    residual = hidden_states

    hidden_states = ModuleMarkerStartWrapper()(hidden_states)
    hidden_states = self.input_layernorm(hidden_states)

    # V30 diagnostic: SKIP_ATTN_COMPUTE skips attention, keeps MoE.
    skip_attn = os.environ.get("SKIP_ATTN_COMPUTE") == "1"
    deltanet_states = None

    if skip_attn:
        # Bypass the attention sub-module: carry the residual forward and
        # emit a zero KV placeholder.
        hidden_states = residual
        present_key_value = self._make_dummy_kv(hidden_states)
        if self.layer_type != "linear_attention":
            cos_cache, sin_cache = None, None
    elif self.layer_type == "linear_attention":
        # DeltaNet path: returns (output, dummy_kv, new_recurrent, new_conv).
        attn_out, dummy_kv, new_rec_state, new_conv_state = self.linear_attn(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            padding_mask=padding_mask,
            **kwargs,
        )
        hidden_states = residual + attn_out
        present_key_value = dummy_kv
        deltanet_states = (new_rec_state, new_conv_state)
        # cos/sin are deliberately left untouched so that later GQA layers
        # still receive the pre-computed values.
    else:
        # Standard attention path; the output gate lives inside self_attn.
        hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            cos_cache=cos_cache,
            sin_cache=sin_cache,
            **kwargs,
        )
        hidden_states = residual + hidden_states

    # V30 diagnostic: SKIP_MOE_COMPUTE skips MoE, keeps attention.
    skip_moe = os.environ.get("SKIP_MOE_COMPUTE") == "1"

    if not skip_moe:
        # MoE FFN (routed experts + sigmoid-gated shared expert).
        residual = hidden_states

        # Normalization strategy with the fused MoE TKG kernel:
        # - CTE (seq_len > 1): this layer applies post_attention_layernorm.
        # - TKG (seq_len == 1): this layer skips it for the MoE input; the
        #   fused kernel is expected to normalize internally.
        is_tkg = self.moe_fused_nki_kernel_enabled and hidden_states.shape[1] == 1
        if not is_tkg:
            hidden_states = self.post_attention_layernorm(hidden_states)

        is_speculative_decoding = (
            self.config.neuron_config.enable_fused_speculation
            and not self.config.neuron_config.is_prefill_stage
        )
        moe_output = self.mlp(
            hidden_states,
            padding_mask,
            is_speculative_decoding=is_speculative_decoding,
        )[0]

        # Shared-expert handling depends on the path:
        # - CTE: the MoE module already added the shared-expert output.
        # - TKG: the fused kernel skipped it, so compute it here and add it
        #   after a TP-only all-reduce.
        if is_tkg:
            from neuronx_distributed.parallel_layers import (
                mappings,
                parallel_state,
            )

            shared_input = self.post_attention_layernorm(residual)
            shared_input_flat = shared_input.reshape(-1, shared_input.shape[-1])
            shared_output = self.mlp.shared_experts(
                shared_input_flat, shared_input.shape[1]
            )
            # SigmoidGatedSharedExperts.forward() divides by ep_degree to
            # offset world_group over-counting in the CTE path. The reduce
            # below is TP-only, so undo that scaling first.
            ep_degree = self.config.neuron_config.ep_degree
            if ep_degree > 1:
                shared_output = shared_output * ep_degree
            shared_output = mappings.reduce_from_tensor_model_parallel_region(
                shared_output,
                process_group=parallel_state.get_tensor_model_parallel_group(),
            )
            shared_output = shared_output.view(moe_output.shape)
            moe_output = moe_output + shared_output

        hidden_states = residual + moe_output

    hidden_states = ModuleMarkerEndWrapper()(hidden_states)
    return (
        hidden_states,
        present_key_value,
        cos_cache,
        sin_cache,
        None,
        deltanet_states,
    )

def _make_dummy_kv(self, hidden_states):
    """Return a zero (K, V) pair shaped (B, kv_heads_per_rank, S, head_dim).

    Used by the diagnostic paths that bypass attention but must still hand a
    KV entry of the expected shape back to the caller. dtype and device
    follow ``hidden_states``.
    """
    bsz, seq_len, _ = hidden_states.shape
    tp = getattr(self.config.neuron_config, "tp_degree", 1)
    # At least one KV head per rank, even when tp exceeds the KV head count.
    kv_heads_per_rank = max(self.config.num_key_value_heads // tp, 1)
    dummy_k = torch.zeros(
        bsz,
        kv_heads_per_rank,
        seq_len,
        self.config.head_dim,
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    return dummy_k, torch.zeros_like(dummy_k)
def encode_vision_to_input(self, inputs_embeds, vision_embeddings, vision_mask):
    """Scatter vision embeddings into the text embeddings at given positions.

    A copy of ``inputs_embeds`` is flattened to rows of size ``hidden_size``
    and the rows selected by ``vision_mask`` are overwritten (not
    accumulated) with the rows of ``vision_embeddings``. The input tensor
    itself is not modified.

    Args:
        inputs_embeds: (batch_size, seq_len, hidden_size) text token embeddings.
        vision_embeddings: (batch_size, n_vision_tokens, hidden_size) vision
            encoder output; flattened to (-1, hidden_size).
        vision_mask: (batch_size, n_vision_tokens, 1) integer position
            indices. After flattening they index rows of the flattened
            (batch_size * seq_len, hidden_size) embedding matrix.

    Returns:
        A new tensor shaped like ``inputs_embeds`` with the vision rows
        written in.
    """
    embedding_dim = inputs_embeds.shape[-1]

    # Force a contiguous copy so the flattened view below always aliases
    # h_new, whatever the strides of the incoming tensor are.
    h_new = inputs_embeds.clone(memory_format=torch.contiguous_format)

    # reshape() instead of view(): these operands are only read, and they
    # may arrive non-contiguous.
    vision_flat = vision_embeddings.reshape(-1, embedding_dim)
    positions_flat = vision_mask.reshape(-1)

    h_new.view(-1, embedding_dim).index_put_(
        (positions_flat,), vision_flat, accumulate=False
    )
    return h_new
+ """ + batch_size, seq_length = input_ids.shape[:2] + if self.config.neuron_config.layer_boundary_markers: + input_ids = ModuleMarkerStartWrapper()(input_ids) + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][1].shape[2] + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # Vision embedding injection (scatter vision tokens into text embeddings) + if (vision_embeddings is not None) and (vision_mask is not None): + if vision_embeddings.dtype != self.config.neuron_config.torch_dtype: + vision_embeddings = vision_embeddings.to( + self.config.neuron_config.torch_dtype + ) + if is_for_context_encoding: + inputs_embeds = self.encode_vision_to_input( + inputs_embeds, vision_embeddings, vision_mask + ) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + hidden_states = inputs_embeds + + # Get KV cache for TKG + cache_size = self.n_positions + if not is_for_context_encoding: + if self.kv_mgr is not None: + past_key_values = self.kv_mgr.get_cache( + seq_ids=seq_ids, + seq_len=cache_size, + is_for_context_encoding=is_for_context_encoding, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + + # Decoder layers + next_decoder_cache = () + deltanet_state_tensors = [] # Collect DeltaNet states + cos_cache = None + sin_cache = None + + # CRITICAL V42: Save the 2D attention_mask BEFORE converting to 4D. + # This is the raw traced input [1,1,...,1,0,0,...,0] (valid=1, pad=0). + # We pass this directly to DeltaNet layers as padding_mask. 
+ # Because it's a direct function of a traced input (no comparison/reduction), + # XLA CANNOT constant-fold it away. This is the correct way to pass + # padding information to DeltaNet layers. + padding_mask_2d = None + if ( + attention_mask is not None + and attention_mask.ndim == 2 + and is_for_context_encoding + ): + # Cast to float (bf16) for multiplication with hidden_states later + padding_mask_2d = attention_mask.to(torch.bfloat16) # (B, S) float + + # Convert 2D attention_mask (B, S) to 4D causal mask (B, 1, S, S) for + # the softmax attention fallback path (perform_prefill with head_dim>128). + # With BS=1, the 2D mask broadcasts accidentally. With BS>1, it doesn't. + # Flash attention doesn't use this mask (uses internal causal masking). + if ( + attention_mask is not None + and attention_mask.ndim == 2 + and is_for_context_encoding + ): + # Build causal mask: position i can attend to positions 0..i + causal = torch.ones( + (seq_length, seq_length), + dtype=torch.bool, + device=attention_mask.device, + ).tril() + # Combine with padding mask: (B, 1, 1, S) & (1, 1, S, S) + padding_4d = attention_mask[:, None, None, :].to(torch.bool) # (B, 1, 1, S) + attention_mask = (causal[None, None, :, :] & padding_4d).to( + attention_mask.dtype + ) # (B, 1, S, S) + + # Phase 2 mRoPE: Pre-compute cos/sin from 3D position_ids when available. + # For CTE with VL content, rotary_position_ids is (3, B, S) with T/H/W positions. + # For text-only CTE, it's (3, B, S) with T=H=W=sequential. + # For TKG, it's None (set_none_if_empty converted torch.zeros((0,)) to None). + # Pre-computing here ensures all layers receive the same mRoPE cos/sin, + # and DeltaNet layers pass them through unchanged. 
+ if rotary_position_ids is not None and rotary_position_ids.ndim == 3: + cos_cache, sin_cache = self.mrope_emb(inputs_embeds, rotary_position_ids) + + for idx, decoder_layer in enumerate(self.layers): + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + layer_outputs = decoder_layer( + hidden_states, + seq_ids=seq_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + active_mask=active_mask, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rotary_position_ids=rotary_position_ids, + kv_mgr=self.kv_mgr, + get_kv_per_layer=False, + update_kv_per_layer=False, + idx=idx, + is_for_context_encoding=is_for_context_encoding, + seq_len=cache_size, + residual=None, + local_mask=local_attn_mask, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + padding_mask=padding_mask_2d + if padding_mask_2d is not None + else padding_mask, + **kwargs, + ) + + hidden_states = layer_outputs[0] + kv = layer_outputs[1] + next_decoder_cache += (kv,) + cos_cache, sin_cache = layer_outputs[2:4] + + # Collect DeltaNet state tensors (element 5) + deltanet_states = layer_outputs[5] if len(layer_outputs) > 5 else None + if deltanet_states is not None: + # deltanet_states = (new_recurrent_state, new_conv_state) + deltanet_state_tensors.append(deltanet_states[0]) # recurrent + deltanet_state_tensors.append(deltanet_states[1]) # conv + + # Update KV cache + if update_cache: + next_decoder_cache = self.kv_mgr.update_cache( + is_for_context_encoding=is_for_context_encoding, + seq_ids=seq_ids, + position_ids=position_ids, + new_key_values=next_decoder_cache, + seq_len=cache_size, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + # Store DeltaNet state tensors for forward() to append to output + self._deltanet_updated_states = deltanet_state_tensors + + return (hidden_states, 
def forward(
    self,
    input_ids,
    attention_mask,
    position_ids,
    seq_ids,
    sampling_params,
    prev_hidden=None,
    adapter_ids=None,
    accepted_indices=None,
    current_length=None,
    medusa_mask=None,
    scatter_index=None,
    slot_mapping=None,
    active_block_table=None,
    num_queries=None,
    computed_context_lens=None,
    tile_q_indices=None,
    tile_block_tables=None,
    tile_masks=None,
    inputs_embeds=None,
    kv_cache=None,
    active_mask=None,
    rotary_position_id=None,
    vision_embeddings=None,
    vision_mask=None,
):
    """Traced entry point: run the model and append DeltaNet states to the output.

    The output list is ``[result] + [logits?] + updated_kv_cache +
    deltanet_state_tensors``. ``result`` is the on-device sampling result
    when on-device sampling is enabled, otherwise the logits. The logits are
    repeated only when ``neuron_config.output_logits`` is set.

    The DeltaNet tensors are placed after the KV-cache entries; the
    input_output_aliases mapping relies on that ordering.

    This method does not delegate to the parent ``forward`` because the
    DeltaNet tensors have to be appended to the output list; it re-implements
    the parts of the flow it needs.

    Several parameters are normalised with ``set_none_if_empty`` but not used
    afterwards; they are kept so the signature stays unchanged.

    Returns:
        list: the outputs in the order described above.
    """
    prev_hidden = self.set_none_if_empty(prev_hidden)
    adapter_ids = self.set_none_if_empty(adapter_ids)
    accepted_indices = self.set_none_if_empty(accepted_indices)
    current_length = self.set_none_if_empty(current_length)
    medusa_mask = self.set_none_if_empty(medusa_mask)
    scatter_index = self.set_none_if_empty(scatter_index)
    slot_mapping = self.set_none_if_empty(slot_mapping)
    active_block_table = self.set_none_if_empty(active_block_table)
    num_queries = self.set_none_if_empty(num_queries)
    computed_context_lens = self.set_none_if_empty(computed_context_lens)
    tile_q_indices = self.set_none_if_empty(tile_q_indices)
    tile_block_tables = self.set_none_if_empty(tile_block_tables)
    tile_masks = self.set_none_if_empty(tile_masks)
    inputs_embeds = self.set_none_if_empty(inputs_embeds)
    kv_cache = self.set_none_if_empty(kv_cache)
    active_mask = self.set_none_if_empty(active_mask)
    rotary_position_id = self.set_none_if_empty(rotary_position_id)
    vision_embeddings = self.set_none_if_empty(vision_embeddings)
    vision_mask = self.set_none_if_empty(vision_mask)

    # Context encoding is any call whose position_ids width is neither 1
    # (token generation) nor the configured speculation length.
    is_for_context_encoding = position_ids.shape[-1] != 1 and not (
        hasattr(self.neuron_config, "speculation_length")
        and position_ids.shape[-1] == self.neuron_config.speculation_length
    )

    seq_ids = seq_ids.to(torch.int32)

    hidden_states, updated_kv_cache = self.get_model_output(
        input_ids=input_ids,
        seq_ids=seq_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        active_mask=active_mask,
        inputs_embeds=inputs_embeds,
        adapter_ids=adapter_ids,
        rotary_position_ids=rotary_position_id,
        update_cache=True,
        is_for_context_encoding=is_for_context_encoding,
        padding_mask=None,
        active_block_table=active_block_table,
        scatter_index=slot_mapping
        if getattr(self, "is_block_kv_layout", False)
        else scatter_index,
        vision_embeddings=vision_embeddings,
        vision_mask=vision_mask,
    )

    # Context encoding: keep only the hidden state at the position holding
    # the largest position id in each row. Token generation keeps
    # hidden_states as produced.
    if is_for_context_encoding and not getattr(self, "sliced_hidden", False):
        batch_size = input_ids.shape[0]
        index = torch.max(position_ids, dim=1, keepdim=True).indices
        index = index.unsqueeze(1).expand(batch_size, 1, self.hidden_size)
        hidden_states = torch.gather(hidden_states, dim=1, index=index)

    logits = self.lm_head(hidden_states)
    logits = logits.float()

    if hasattr(self.lm_head, "pad_size"):
        if self.lm_head.gather_output:
            rank_id = torch.tensor(0, device=logits.device, dtype=torch.int32)
            world_size = 1
        else:
            rank_id = self.rank_util.get_rank()
            world_size = torch.distributed.get_world_size(
                group=self.lm_head.tensor_parallel_group
            )
        from neuronx_distributed_inference.models.model_base import (
            mask_padded_logits,
        )

        logits = mask_padded_logits(
            logits, rank_id, world_size, pad_size=self.lm_head.pad_size
        )

    if self.on_device_sampling:
        res = self._sample_on_device(
            logits, sampling_params, False, is_for_context_encoding
        )
    else:
        res = logits

    outputs = [res]
    if self.neuron_config.output_logits:
        outputs += [logits]
    outputs += updated_kv_cache

    # Append DeltaNet state tensors (for input_output_aliases). They are
    # stored on self by get_model_output().
    if hasattr(self, "_deltanet_updated_states"):
        outputs += self._deltanet_updated_states

    return outputs
+ + Weight mappings per layer type: + + DeltaNet layers (linear_attention): + HF: layers.X.linear_attn.{in_proj_qkv, in_proj_z, in_proj_a, in_proj_b, + conv1d, A_log, dt_bias, norm, out_proj} + NxDI: same names (no remapping needed) + + Full attention layers: + HF: layers.X.self_attn.q_proj.weight: (8192, 2048) -- doubled for gate + NxDI: layers.X.self_attn.Wqkv.weight (fused Q+K+V, gate separated) + layers.X.self_attn.output_gate_proj.weight (gate part) + HF: layers.X.self_attn.{k_proj, v_proj, o_proj, q_norm, k_norm} + NxDI: layers.X.self_attn.{..., q_layernorm, k_layernorm} + + MoE (all layers): + HF: layers.X.mlp.gate.weight -> NxDI: layers.X.mlp.router.linear_router.weight + HF: layers.X.mlp.experts.gate_up_proj -> NxDI: layers.X.mlp.expert_mlps.mlp_op.gate_up_proj.weight + HF: layers.X.mlp.experts.down_proj -> NxDI: layers.X.mlp.expert_mlps.mlp_op.down_proj.weight + HF: layers.X.mlp.shared_expert.{gate,up,down}_proj -> NxDI: layers.X.mlp.shared_experts.shared_experts.{gate,up,down}_proj + HF: layers.X.mlp.shared_expert_gate.weight -> NxDI: layers.X.mlp.shared_experts.sigmoid_gate.weight + """ + # Add rank_util + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + # CRITICAL V42: Zero out the pad token embedding. + # NxDI right-pads CTE inputs to bucket size with pad_token_id. DeltaNet has no + # causal attention mask, so it processes ALL positions including padding. + # XLA constant-folds any masking operation we put inside the traced model. + # The only solution is to make pad token embeddings ZERO at the weight level, + # so that padding positions contribute nothing to the DeltaNet recurrence + # (zero input → zero projection → zero key/value/query → no state update). 
+ pad_token_id = config.pad_token_id + embed_key = "embed_tokens.weight" + if embed_key in neuron_state_dict and pad_token_id is not None: + emb_weight = neuron_state_dict[embed_key] + old_norm = emb_weight[pad_token_id].float().norm().item() + emb_weight[pad_token_id] = 0.0 + print( + f" [PAD EMBED] Zeroed embedding for pad_token_id={pad_token_id} " + f"(was norm={old_norm:.4f})" + ) + neuron_state_dict[embed_key] = emb_weight + + # CRITICAL: Convert (1+weight) RMSNorm weights to standard RMSNorm weights. + # Qwen3.5-MoE uses RMSNorm with `output = norm(x) * (1 + weight)` where weight + # is initialized to zeros. Standard NxDI RMSNorm uses `output = norm(x) * weight` + # where weight is initialized to ones. To convert: new_weight = old_weight + 1.0 + # This affects: input_layernorm, post_attention_layernorm, q_norm, k_norm, final norm + # but NOT the DeltaNet internal RMSNormGated (which uses standard weight * norm(x)) + norm_keys_to_convert = [] + for l in range(config.num_hidden_layers): + norm_keys_to_convert.append(f"layers.{l}.input_layernorm.weight") + norm_keys_to_convert.append(f"layers.{l}.post_attention_layernorm.weight") + if config.layer_types[l] == "full_attention": + norm_keys_to_convert.append(f"layers.{l}.self_attn.q_norm.weight") + norm_keys_to_convert.append(f"layers.{l}.self_attn.k_norm.weight") + norm_keys_to_convert.append("norm.weight") + + for nk in norm_keys_to_convert: + if nk in neuron_state_dict: + old_val = neuron_state_dict[nk] + neuron_state_dict[nk] = old_val.float() + 1.0 + if "layers.0." in nk or nk == "norm.weight": + print( + f" [NORM FIX] {nk}: mean {old_val.float().mean():.4f} -> {neuron_state_dict[nk].mean():.4f}" + ) + else: + if "layers.0." 
in nk or nk == "norm.weight": + print(f" [NORM FIX] WARNING: key not found: {nk}") + print( + f" Available keys (sample): {[k for k in neuron_state_dict.keys() if 'norm' in k.lower()][:5]}" + ) + + for l in range(config.num_hidden_layers): + layer_type = config.layer_types[l] + + # === Attention layers === + if layer_type == "full_attention": + neuron_state_dict[f"layers.{l}.self_attn.rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + # QK norms: q_norm -> q_layernorm, k_norm -> k_layernorm + q_norm_key = f"layers.{l}.self_attn.q_norm.weight" + k_norm_key = f"layers.{l}.self_attn.k_norm.weight" + if q_norm_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.q_layernorm.weight"] = ( + neuron_state_dict.pop(q_norm_key).detach().clone() + ) + if k_norm_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.k_layernorm.weight"] = ( + neuron_state_dict.pop(k_norm_key).detach().clone() + ) + + # q_proj is doubled: (8192, 2048) = (num_heads * head_dim * 2, hidden) + # The weight is INTERLEAVED by head: + # [head0_query(256) | head0_gate(256) | head1_query(256) | head1_gate(256) | ...] + # We need to deinterleave into separate query and gate weights. 
+ q_proj_key = f"layers.{l}.self_attn.q_proj.weight" + if q_proj_key in neuron_state_dict: + q_proj_w = neuron_state_dict.pop(q_proj_key) + num_heads = config.num_attention_heads # 16 + head_dim = config.head_dim # 256 + # Reshape to (num_heads, head_dim*2, hidden_size) + q_proj_w = q_proj_w.reshape(num_heads, head_dim * 2, config.hidden_size) + # Split each head's output into query and gate + query_w = q_proj_w[:, :head_dim, :] # (16, 256, 2048) + gate_w = q_proj_w[:, head_dim:, :] # (16, 256, 2048) + # Reshape back to (num_heads * head_dim, hidden_size) + query_w = query_w.reshape( + num_heads * head_dim, config.hidden_size + ) # (4096, 2048) + gate_w = gate_w.reshape( + num_heads * head_dim, config.hidden_size + ) # (4096, 2048) + + # Store query part back as q_proj for Wqkv fusion + neuron_state_dict[q_proj_key] = query_w + # Store gate weights for the output_gate_proj ColumnParallelLinear + neuron_state_dict[f"layers.{l}.self_attn.output_gate_proj.weight"] = ( + gate_w + ) + + # Fuse QKV + if config.neuron_config.fused_qkv: + q_key = f"layers.{l}.self_attn.q_proj.weight" + k_key = f"layers.{l}.self_attn.k_proj.weight" + v_key = f"layers.{l}.self_attn.v_proj.weight" + if q_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = torch.cat( + [ + neuron_state_dict[q_key], + neuron_state_dict[k_key], + neuron_state_dict[v_key], + ] + ) + del neuron_state_dict[q_key] + del neuron_state_dict[k_key] + del neuron_state_dict[v_key] + + # === DeltaNet layers: deinterleave in_proj_qkvz and in_proj_ba === + elif layer_type == "linear_attention": + # HF stores `in_proj_qkvz` as (12288, 2048) in GROUPED/INTERLEAVED format: + # (num_k_heads=16, per_group=768, hidden) where each group = [q(128), k(128), v(256), z(256)] + # We need: in_proj_qkv (8192, 2048) = flat [all_Q | all_K | all_V] + # in_proj_z (4096, 2048) = flat [all_Z] + qkvz_key = f"layers.{l}.linear_attn.in_proj_qkvz.weight" + if qkvz_key in neuron_state_dict: + w = 
neuron_state_dict.pop(qkvz_key) + num_k_heads = config.linear_num_key_heads # 16 + head_k_dim = config.linear_key_head_dim # 128 + head_v_dim = config.linear_value_head_dim # 128 + num_v_heads = config.linear_num_value_heads # 32 + # Each group: q(head_k_dim=128) + k(head_k_dim=128) + v(head_v_dim*2=256) + z(head_v_dim*2=256) = 768 + v_per_group = head_v_dim * (num_v_heads // num_k_heads) # 128 * 2 = 256 + z_per_group = v_per_group # 256 + group_size = head_k_dim + head_k_dim + v_per_group + z_per_group # 768 + w = w.reshape(num_k_heads, group_size, config.hidden_size) + q_parts = w[:, :head_k_dim, :] # (16, 128, 2048) + k_parts = w[:, head_k_dim : head_k_dim * 2, :] # (16, 128, 2048) + v_parts = w[ + :, head_k_dim * 2 : head_k_dim * 2 + v_per_group, : + ] # (16, 256, 2048) + z_parts = w[:, head_k_dim * 2 + v_per_group :, :] # (16, 256, 2048) + # Flatten: (16, dim, 2048) -> (16*dim, 2048) + qkv_w = torch.cat( + [ + q_parts.reshape(-1, config.hidden_size), # (2048, 2048) + k_parts.reshape(-1, config.hidden_size), # (2048, 2048) + v_parts.reshape(-1, config.hidden_size), # (4096, 2048) + ], + dim=0, + ) # (8192, 2048) + z_w = z_parts.reshape(-1, config.hidden_size) # (4096, 2048) + neuron_state_dict[f"layers.{l}.linear_attn.in_proj_qkv.weight"] = qkv_w + neuron_state_dict[f"layers.{l}.linear_attn.in_proj_z.weight"] = z_w + if l == 0: + print( + f" [DELTANET] Deinterleaved in_proj_qkvz ({list(neuron_state_dict.get(qkvz_key, w).shape)}) -> in_proj_qkv {tuple(qkv_w.shape)}, in_proj_z {tuple(z_w.shape)}" + ) + + # HF stores `in_proj_ba` as (64, 2048) in grouped format: + # (num_k_heads=16, 4, hidden) where each group = [b(2), a(2)] + # We need: in_proj_b (32, 2048) and in_proj_a (32, 2048) + ba_key = f"layers.{l}.linear_attn.in_proj_ba.weight" + if ba_key in neuron_state_dict: + w = neuron_state_dict.pop(ba_key) + num_k_heads = config.linear_num_key_heads # 16 + num_v_heads = config.linear_num_value_heads # 32 + heads_per_group = num_v_heads // num_k_heads # 2 + # Each 
group: b(heads_per_group=2) + a(heads_per_group=2) = 4 + w = w.reshape(num_k_heads, 2 * heads_per_group, config.hidden_size) + b_parts = w[:, :heads_per_group, :] # (16, 2, 2048) + a_parts = w[:, heads_per_group:, :] # (16, 2, 2048) + b_w = b_parts.reshape(-1, config.hidden_size) # (32, 2048) + a_w = a_parts.reshape(-1, config.hidden_size) # (32, 2048) + neuron_state_dict[f"layers.{l}.linear_attn.in_proj_b.weight"] = b_w + neuron_state_dict[f"layers.{l}.linear_attn.in_proj_a.weight"] = a_w + if l == 0: + print( + f" [DELTANET] Deinterleaved in_proj_ba ({list(w.shape)}) -> in_proj_b {tuple(b_w.shape)}, in_proj_a {tuple(a_w.shape)}" + ) + + # === MoE weights (ALL layers have MoE MLPs) === + # Fuse individual experts if stored separately (per-expert weights) + # Some HF checkpoints store experts individually as layers.X.mlp.experts.N.{gate,up,down}_proj + # Check if we need to fuse them + expert0_gate_key = f"layers.{l}.mlp.experts.0.gate_proj.weight" + if expert0_gate_key in neuron_state_dict: + num_experts = config.num_experts # 512 + gate_projs = [] + up_projs = [] + down_projs = [] + for e in range(num_experts): + gate_projs.append( + neuron_state_dict.pop( + f"layers.{l}.mlp.experts.{e}.gate_proj.weight" + ) + ) + up_projs.append( + neuron_state_dict.pop(f"layers.{l}.mlp.experts.{e}.up_proj.weight") + ) + down_projs.append( + neuron_state_dict.pop( + f"layers.{l}.mlp.experts.{e}.down_proj.weight" + ) + ) + # Stack: (E, intermediate, hidden) for gate/up, (E, hidden, intermediate) for down + gate_up = torch.cat([torch.stack(gate_projs), torch.stack(up_projs)], dim=1) + # gate_up: (E, 2*I, H), need (E, H, 2*I) for NxDI + gate_up = gate_up.permute(0, 2, 1).contiguous() + down = torch.stack(down_projs) + # down: (E, H, I), need (E, I, H) for NxDI + down = down.permute(0, 2, 1).contiguous() + neuron_state_dict[ + f"layers.{l}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + ] = gate_up + neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.down_proj.weight"] = ( + down + 
) + if l == 0: + print( + f" [MOE] Fused {num_experts} individual experts -> gate_up_proj {list(gate_up.shape)}, down_proj {list(down.shape)}" + ) + + # === MoE weights === + # Router + gate_key = f"layers.{l}.mlp.gate.weight" + if gate_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.mlp.router.linear_router.weight"] = ( + neuron_state_dict.pop(gate_key).detach().clone() + ) + + # Fused expert weights + # HF pre-fused: experts.gate_up_proj (E, 2*I, H) -- need transpose to NxDI (E, H, 2*I) + # HF pre-fused: experts.down_proj (E, H, I) -- need transpose to NxDI (E, I, H) + gate_up_key = f"layers.{l}.mlp.experts.gate_up_proj" + down_key = f"layers.{l}.mlp.experts.down_proj" + + if gate_up_key in neuron_state_dict: + w = neuron_state_dict.pop(gate_up_key).detach().clone() + # Transpose: (E, 2*I, H) -> (E, H, 2*I) + w = w.permute(0, 2, 1).contiguous() + # Apply padding if needed + pad_size = getattr(config, "moe_intermediate_pad_size", 0) + if pad_size > 0: + I = w.shape[2] // 2 + w = w.reshape(config.num_experts, config.hidden_size, 2, I) + w = torch.nn.functional.pad(w, (0, pad_size)) + w = w.reshape(config.num_experts, config.hidden_size, -1) + neuron_state_dict[ + f"layers.{l}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + ] = w + + if down_key in neuron_state_dict: + w = neuron_state_dict.pop(down_key).detach().clone() + # Transpose: (E, H, I) -> (E, I, H) + w = w.permute(0, 2, 1).contiguous() + # Apply padding if needed + pad_size = getattr(config, "moe_intermediate_pad_size", 0) + if pad_size > 0: + w = torch.nn.functional.pad(w, (0, 0, 0, pad_size)) + neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.down_proj.weight"] = w + + # Shared expert weights (SigmoidGatedSharedExperts wrapping NxDI SharedExperts) + # HF: mlp.shared_expert.{gate_proj, up_proj, down_proj} + # NxDI: mlp.shared_experts.shared_experts.{gate_proj, up_proj, down_proj} + # (mlp.shared_experts is SigmoidGatedSharedExperts, inner .shared_experts is NxDI SharedExperts) + for proj in 
["gate_proj", "up_proj", "down_proj"]: + hf_key = f"layers.{l}.mlp.shared_expert.{proj}.weight" + nxdi_key = f"layers.{l}.mlp.shared_experts.shared_experts.{proj}.weight" + if hf_key in neuron_state_dict: + neuron_state_dict[nxdi_key] = ( + neuron_state_dict.pop(hf_key).detach().clone() + ) + + # Shared expert sigmoid gate: mlp.shared_expert_gate.weight + # -> mlp.shared_experts.sigmoid_gate.weight + seg_key = f"layers.{l}.mlp.shared_expert_gate.weight" + if seg_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.mlp.shared_experts.sigmoid_gate.weight"] = ( + neuron_state_dict.pop(seg_key).detach().clone() + ) + + # Fused MoE TKG aliased weights + if getattr(config, "moe_fused_nki_kernel_enabled", False): + # MoEFusedTKG has a separate (non-shared) RMSNorm that + # needs the same weights as post_attention_layernorm. + post_attn_key = f"layers.{l}.post_attention_layernorm.weight" + if post_attn_key in neuron_state_dict: + neuron_state_dict[ + f"layers.{l}.mlp.moe_fused_tkg.post_attention_layernorm.weight" + ] = neuron_state_dict[post_attn_key].clone() + + # Router transposed weight (required by fused TKG kernel) + router_key = f"layers.{l}.mlp.router.linear_router.weight" + if router_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.mlp.router.weight_T"] = ( + neuron_state_dict[router_key].detach().T.clone() + ) + + gc.collect() + + return neuron_state_dict + + +# ============================================================ +# Custom ModelWrapper and DecoderModelInstance for DeltaNet state aliasing +# ============================================================ + + +class Qwen35DecoderModelInstance(DecoderModelInstance): + """Custom DecoderModelInstance that adds DeltaNet state buffers to input_output_aliases. + + After the standard KV cache aliases, we add aliases for each DeltaNet layer's + recurrent_state_buffer and conv_state_buffer. This allows the XLA runtime to + carry state between CTE and TKG graphs via shared HBM buffers. 
def input_generator(self):
    """Build the 24-argument trace inputs for every bucket.

    Method of ``Qwen35ModelWrapper``. With prefix caching enabled the work
    is delegated to ``_input_generator_prefix_caching``. Otherwise each
    bucket produced by the base class is extended as follows:

    - positions 0-6: the base-class inputs, with attention_mask (1) and
      position_ids (2) replaced for context encoding (see below);
    - up to position 20: empty int32 placeholders;
    - position 21: mRoPE position ids, (3, BS, seq_len) for context
      encoding, empty for token generation;
    - position 22: vision_embeddings, (BS, seq_len, hidden_size) zeros for
      context encoding, empty for token generation;
    - position 23: vision_mask, (BS, seq_len, 1) filled with ``seq_len - 1``
      for context encoding, empty for token generation.

    For context encoding (more than one active token) the first half of the
    sequence is traced as valid (positions ``0..half-1``, mask 1) and the
    rest as padding (position 0, mask 0). Per the original author's note
    this keeps the compiler from constant-folding the DeltaNet padding mask.

    Returns:
        A list with one tuple of tensors per bucket.
    """
    if self.is_prefix_caching:
        return self._input_generator_prefix_caching()

    traced_inputs = []
    for base_bucket in super().input_generator():
        token_ids = base_bucket[0]
        bs = token_ids.shape[0]
        seq_len = token_ids.shape[1]
        trace_args = list(base_bucket)

        if seq_len > 1:
            # Context encoding: half valid, half simulated padding.
            n_valid = seq_len // 2
            n_pad = seq_len - n_valid
            # Reuse the dtypes the base class picked for these two inputs.
            pos_dtype = base_bucket[2].dtype
            mask_dtype = base_bucket[1].dtype

            pos_row = torch.cat(
                [
                    torch.arange(n_valid, dtype=pos_dtype),
                    torch.zeros(n_pad, dtype=pos_dtype),
                ]
            )
            mask_row = torch.cat(
                [
                    torch.ones(n_valid, dtype=mask_dtype),
                    torch.zeros(n_pad, dtype=mask_dtype),
                ]
            )
            trace_args[1] = mask_row.unsqueeze(0).expand(bs, -1).contiguous()
            trace_args[2] = pos_row.unsqueeze(0).expand(bs, -1).contiguous()

            # Same sequential ids on all three mRoPE axes.
            mrope_ids = (
                torch.arange(0, seq_len, dtype=torch.int32)
                .unsqueeze(0)
                .unsqueeze(0)
                .expand(3, bs, -1)
                .contiguous()
            )
            vis_emb = torch.zeros(
                (bs, seq_len, self.config.hidden_size),
                dtype=self.config.neuron_config.torch_dtype,
            )
            # Every entry points at the last position.
            vis_mask = torch.full(
                (bs, seq_len, 1),
                fill_value=seq_len - 1,
                dtype=torch.int32,
            )
        else:
            # Token generation: the extra inputs are empty placeholders.
            mrope_ids = torch.zeros((0,), dtype=torch.int32)
            vis_emb = torch.zeros((0,), dtype=self.config.neuron_config.torch_dtype)
            vis_mask = torch.zeros((0,), dtype=torch.int32)

        # Fill the unused slots so the extras land at positions 21-23.
        while len(trace_args) < 21:
            trace_args.append(torch.zeros((0,), dtype=torch.int32))
        trace_args.extend([mrope_ids, vis_emb, vis_mask])

        traced_inputs.append(tuple(trace_args))

    return traced_inputs
def pad_inputs(self, *args, pad_type="first_fit"):
    """Pad inputs to the bucket size, including the mRoPE and vision args.

    Method of ``Qwen35ModelWrapper``. The caller's arguments at positions
    21-23 (mRoPE position ids, vision_embeddings, vision_mask) are saved
    before delegating to the base class, because (per the original author's
    note) the base implementation may replace the vision inputs. After the
    base class has padded the standard inputs, positions 21-23 are rebuilt
    from the saved originals for context encoding (padded length > 1):

    - mRoPE ids are extended by continuing the last position (+1, +2, ...);
    - vision_embeddings are extended with zeros, or sliced if too long;
    - vision_mask is extended with ``padded_len - 1``; when the saved
      vision_embeddings are all zero (text-only prompt) the whole mask is
      replaced by ``padded_len - 1`` so the scatter only hits the last
      padding slot;
    - finally vision_mask is clamped to at most ``padded_len - 1``.

    position_ids are left exactly as the base class padded them.

    Args:
        *args: positional model inputs; positions 21-23 are optional.
        pad_type: forwarded to the base class.

    Returns:
        The padded arguments; a 24-tuple when positions 21-23 were rebuilt,
        otherwise whatever the base class returned.
    """
    # Keep the caller's originals before the base class touches them.
    saved_mrope = args[21] if len(args) >= 22 else None
    saved_emb = args[22] if len(args) >= 23 else None
    saved_mask = args[23] if len(args) >= 24 else None

    padded_args = super().pad_inputs(*args, pad_type=pad_type)

    if len(padded_args) < 24 or saved_mrope is None:
        return padded_args

    target_len = padded_args[0].shape[1]
    bs = padded_args[0].shape[0]
    if target_len <= 1:
        # Token generation: nothing to rebuild.
        return padded_args

    safe_slot = target_len - 1

    # --- mRoPE position ids: (3, BS, orig_len) -> (3, BS, target_len) ---
    if saved_mrope.ndim == 3 and saved_mrope.shape[-1] != target_len:
        gap = target_len - saved_mrope.shape[-1]
        steps = torch.arange(1, gap + 1, dtype=saved_mrope.dtype)
        steps = steps.unsqueeze(0).unsqueeze(0).expand(3, bs, -1)
        continuation = saved_mrope[:, :, -1:] + steps
        mrope_ids = torch.cat([saved_mrope, continuation], dim=-1)
    elif saved_mrope.ndim == 3:
        mrope_ids = saved_mrope
    else:
        # No usable mRoPE ids were passed: fall back to sequential ids.
        mrope_ids = (
            torch.arange(0, target_len, dtype=torch.int32)
            .unsqueeze(0)
            .unsqueeze(0)
            .expand(3, bs, -1)
            .contiguous()
        )

    # --- vision embeddings: (BS, orig_len, H) -> (BS, target_len, H) ---
    emb_is_3d = saved_emb is not None and saved_emb.ndim == 3
    if emb_is_3d and saved_emb.shape[1] < target_len:
        filler = torch.zeros(
            (bs, target_len - saved_emb.shape[1], saved_emb.shape[2]),
            dtype=saved_emb.dtype,
        )
        vis_emb = torch.cat([saved_emb, filler], dim=1)
    elif emb_is_3d:
        vis_emb = saved_emb[:, :target_len]
    else:
        vis_emb = torch.zeros(
            (bs, target_len, self.config.hidden_size),
            dtype=self.config.neuron_config.torch_dtype,
        )

    # All-zero embeddings mean there is no vision content at all.
    text_only = emb_is_3d and saved_emb.abs().sum().item() == 0

    # --- vision mask: (BS, orig_len, 1) -> (BS, target_len, 1) ---
    mask_is_3d = saved_mask is not None and saved_mask.ndim == 3
    if not mask_is_3d or text_only:
        # Point every entry at the last slot of the padded sequence.
        vis_mask = torch.full(
            (bs, target_len, 1),
            fill_value=safe_slot,
            dtype=torch.int32,
        )
    elif saved_mask.shape[1] < target_len:
        # Real vision: keep the given targets, extend with the safe slot.
        tail = torch.full(
            (bs, target_len - saved_mask.shape[1], 1),
            fill_value=safe_slot,
            dtype=torch.int32,
        )
        vis_mask = torch.cat([saved_mask, tail], dim=1)
    else:
        vis_mask = saved_mask[:, :target_len]

    # Guard against any target index past the end of the padded sequence.
    vis_mask = vis_mask.clamp(max=safe_slot)

    return (*padded_args[:21], mrope_ids, vis_emb, vis_mask)
def _pad_prefix_caching_inputs(self, *args, pad_type="first_fit"):
    """Pad the mRoPE and vision args after base-class prefix-caching padding.

    Method of ``Qwen35ModelWrapper``. The base class pads the standard
    prefix-caching inputs first. For the context-encoding model with at
    least 24 args and a padded length above 1, positions 21-23 are then
    resized along the sequence dimension to the padded length:

    - 21, mRoPE ids (3, BS, len): extended by continuing the last position,
      or truncated;
    - 22, vision_embeddings (BS, len, H): extended with zeros, or truncated;
    - 23, vision_mask (BS, len, 1): extended with ``padded_len - 1``, or
      truncated.

    Tensors that are not 3-D or already have the padded length pass through
    unchanged.

    Args:
        *args: positional model inputs.
        pad_type: forwarded to the base class.

    Returns:
        The padded arguments; a 24-tuple when positions 21-23 were handled.
    """
    padded_args = super()._pad_prefix_caching_inputs(*args, pad_type=pad_type)

    handles_extras = (
        len(padded_args) >= 24
        and self.tag == CONTEXT_ENCODING_MODEL_TAG
        and padded_args[0].shape[1] > 1
    )
    if not handles_extras:
        return padded_args

    target_len = padded_args[0].shape[1]
    bs = padded_args[0].shape[0]

    # Position 21: mRoPE ids, sequence is the last dimension.
    mrope_ids = padded_args[21]
    if mrope_ids.ndim == 3 and mrope_ids.shape[-1] != target_len:
        current = mrope_ids.shape[-1]
        if current < target_len:
            steps = torch.arange(1, target_len - current + 1, dtype=mrope_ids.dtype)
            steps = steps.unsqueeze(0).unsqueeze(0).expand(3, bs, -1)
            mrope_ids = torch.cat([mrope_ids, mrope_ids[:, :, -1:] + steps], dim=-1)
        else:
            mrope_ids = mrope_ids[:, :, :target_len]

    # Position 22: vision embeddings, sequence is dimension 1.
    vis_emb = padded_args[22]
    if vis_emb.ndim == 3 and vis_emb.shape[1] != target_len:
        if vis_emb.shape[1] < target_len:
            filler = torch.zeros(
                (bs, target_len - vis_emb.shape[1], vis_emb.shape[2]),
                dtype=vis_emb.dtype,
            )
            vis_emb = torch.cat([vis_emb, filler], dim=1)
        else:
            vis_emb = vis_emb[:, :target_len]

    # Position 23: vision mask, new entries point at the last slot.
    vis_mask = padded_args[23]
    if vis_mask.ndim == 3 and vis_mask.shape[1] != target_len:
        if vis_mask.shape[1] < target_len:
            tail = torch.full(
                (bs, target_len - vis_mask.shape[1], 1),
                fill_value=target_len - 1,
                dtype=torch.int32,
            )
            vis_mask = torch.cat([vis_mask, tail], dim=1)
        else:
            vis_mask = vis_mask[:, :target_len]

    return (*padded_args[:21], mrope_ids, vis_emb, vis_mask)
was NOT stripped (e.g., called directly) + elif k.startswith("model.language_model."): + new_k = k.replace("model.language_model.", "", 1) + new_sd[new_k] = v + elif k.startswith("model.visual") or k.startswith("visual"): + continue # Skip vision encoder + elif k.startswith("model."): + new_sd[k.replace("model.", "", 1)] = v + elif k.startswith("mtp."): + continue # Skip MTP + elif k.startswith("lm_head."): + new_sd[k] = v + else: + new_sd[k] = v + + return convert_qwen35_hf_to_neuron_state_dict(new_sd, config) + + def enable_context_encoding(self): + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def enable_token_generation(self): + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + super().enable_token_generation() + + def _copy_past_key_values(self, outputs): + """Override to also copy DeltaNet state buffers on CPU. + + On Neuron, input_output_aliases handles this automatically. + On CPU, we must manually copy the output tensors back to the + nn.Parameter .data attributes on both CTE and TKG models. + """ + # First, call parent to copy KV cache + super()._copy_past_key_values(outputs) + + # Then copy DeltaNet state buffers + # The output layout is: [result] + [logits?] 
+ kv_cache + deltanet_states + num_output_from_trace = 1 + if ( + self.neuron_config.output_logits + and self.neuron_config.on_device_sampling_config + ): + num_output_from_trace = 2 + + # Count KV cache entries + if ( + hasattr(self, "token_generation_model") + and self.token_generation_model is not None + ): + tkg_model = self.token_generation_model.model + cte_model = self.context_encoding_model.model + else: + return + + if tkg_model.kv_mgr is not None: + num_kv = len(tkg_model.kv_mgr.past_key_values) + else: + num_kv = 0 + + # DeltaNet states start after KV cache + state_start = num_output_from_trace + num_kv + + # Get the state params from both models + tkg_params = getattr(tkg_model, "_deltanet_state_params", []) + cte_params = getattr(cte_model, "_deltanet_state_params", []) + + if len(tkg_params) > 0 and state_start + len(tkg_params) <= len(outputs): + for i, (tkg_param, cte_param) in enumerate(zip(tkg_params, cte_params)): + new_state = outputs[state_start + i] + tkg_param.data = new_state + cte_param.data = new_state + + def get_required_kwargs(self): + """Return extra kwargs that must be propagated through the HF generation loop. + + This ensures llava_args (vision_embeddings + vision_mask) flows from + generate() -> prepare_inputs_for_generation() -> forward() -> _get_model_outputs(). + """ + return ["llava_args"] + + def _get_model_outputs( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + medusa_args, + llava_args, + slot_mapping=None, + block_table=None, + full_context_lens=None, + computed_context_lens=None, + tf_args=None, + ): + """Override to pass all 24 positional args explicitly. + + The model is traced with 24 positional args (from Qwen35ModelWrapper.input_generator). + The base class splats *llava_args after 7 args, which puts vision inputs at positions + 7-8 instead of 22-23. 
We override to fill positions 7-20 with torch.empty(0) and + place mrope_position_ids at 21 and vision inputs at 22-23. + + llava_args layout from VL generate(): + [0] vision_embeddings (BS, seq_len, hidden_size) + [1] vision_mask (BS, seq_len, 1) + [2] mrope_position_ids (3, batch, seq_len) -- optional + + For CTE: slot 21 = (3, B, S) mRoPE position IDs. + If not in llava_args, generate sequential IDs with T=H=W (text-only). + For TKG: slot 21 = torch.zeros((0,)) → set_none_if_empty → None → uses 2D position_ids. + """ + # --- PREFIX CACHING PATH --- + # When prefix caching is enabled, use the base class arg layout (positions 0-14) + # with block KV args at positions 11-14, plus Qwen3.5 custom args at 21-23. + if self.neuron_config.is_prefix_caching: + return self._get_model_outputs_prefix_caching( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + llava_args, + slot_mapping, + block_table, + full_context_lens, + computed_context_lens, + ) + + is_prefill = self._is_prefill(position_ids) + + seq_len = input_ids.shape[1] + batch_size = input_ids.shape[0] + + # Extract vision inputs and mRoPE position IDs from llava_args. + # llava_args layout: [vision_embeddings, vision_mask, mrope_position_ids (optional)] + if llava_args and len(llava_args) >= 2: + vision_embeddings = llava_args[0] + vision_mask = llava_args[1] + # mRoPE position IDs: (3, batch, seq_len) if provided + if len(llava_args) >= 3: + mrope_position_ids = llava_args[2] + else: + mrope_position_ids = None + elif is_prefill: + # Text-only CTE: generate dummy vision inputs matching compiled shape. + # The compiled CTE expects (BS, seq_len, hidden_size) and (BS, seq_len, 1). + # Use zeros for embeddings and seq_len-1 for mask (safe scatter target). + # NOTE: Do NOT use large sentinel values (e.g., 2**30) as fill_value -- + # they cause DGE out-of-bounds crashes in the Neuron runtime. 
+ # Using seq_len-1 targets the last position (always a padding slot). + vision_embeddings = torch.zeros( + (batch_size, seq_len, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + (batch_size, seq_len, 1), + fill_value=seq_len - 1, + dtype=torch.int32, + ) + mrope_position_ids = None + else: + # TKG: empty tensors (no vision injection during decode) + vision_embeddings = torch.zeros((0,), dtype=torch.float32) + vision_mask = torch.zeros((0,), dtype=torch.int32) + mrope_position_ids = None + + # For CTE: mRoPE position IDs at slot 21 must be (3, batch, seq_len). + # If not provided (text-only), generate sequential IDs with T=H=W (identical axes). + # For TKG: slot 21 = torch.empty(0) → set_none_if_empty → None → fallback to 2D position_ids. + if is_prefill: + if mrope_position_ids is None: + mrope_position_ids = ( + torch.arange(0, seq_len, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + else: + mrope_position_ids = torch.zeros((0,), dtype=torch.int32) + + # Build the 14 empty tensors for positions 7-20 + # Position 21 = mrope_position_ids, 22 = vision_embeddings, 23 = vision_mask + empties = [torch.empty(0) for _ in range(14)] + + if self._is_prefill(position_ids): + # ----------------------------------------------------------- + # CTE batch splitting: handle ctx_batch_size < tkg_batch_size + # ----------------------------------------------------------- + # NxDI's model_wrapper.forward() splits batches by slicing ALL + # args along dim 0, which corrupts M-RoPE position_ids (shape + # [3, B, S] -- batch is dim 1, not dim 0). It also pads to + # max_batch_size (=tkg_batch_size) instead of ctx_batch_size. + # + # Fix: split batches here with correct dim handling, then pass + # each chunk with batch_size == ctx_batch_size so the wrapper + # takes the fast path (no internal splitting/padding). 
+ # ----------------------------------------------------------- + ctx_bs = self.context_encoding_model.neuron_config.batch_size + output_logits = [] + + for cb in range(0, batch_size, ctx_bs): + cb_end = min(cb + ctx_bs, batch_size) + actual_chunk = cb_end - cb + + # Slice standard 2D args along dim 0 (batch dim) + chunk_input_ids = input_ids[cb:cb_end] + chunk_attn_mask = attention_mask[cb:cb_end] + chunk_pos_ids = position_ids[cb:cb_end] + chunk_seq_ids = seq_ids[cb:cb_end] + chunk_sampling = sampling_params[cb:cb_end] + chunk_prev_hidden = ( + prev_hidden[cb:cb_end] + if prev_hidden is not None + and hasattr(prev_hidden, "ndim") + and prev_hidden.ndim > 0 + and prev_hidden.shape[0] > 0 + else prev_hidden + ) + chunk_adapter_ids = ( + adapter_ids[cb:cb_end] + if adapter_ids is not None + and hasattr(adapter_ids, "ndim") + and adapter_ids.ndim > 0 + and adapter_ids.shape[0] > 0 + else adapter_ids + ) + + # M-RoPE: slice along dim 1 (batch dim for [3, B, S]) + if mrope_position_ids.ndim == 3: + chunk_mrope = mrope_position_ids[:, cb:cb_end, :] + else: + chunk_mrope = mrope_position_ids # empty tensor for TKG + + # Vision args: slice along dim 0 if batched + if vision_embeddings.ndim == 3: + chunk_vis_emb = vision_embeddings[cb:cb_end] + chunk_vis_mask = vision_mask[cb:cb_end] + else: + chunk_vis_emb = vision_embeddings + chunk_vis_mask = vision_mask + + # Pad if chunk is smaller than ctx_batch_size + if actual_chunk < ctx_bs: + pad_n = ctx_bs - actual_chunk + chunk_input_ids = torch.cat( + [chunk_input_ids, chunk_input_ids[:1].expand(pad_n, -1)], dim=0 + ) + chunk_attn_mask = torch.cat( + [chunk_attn_mask, chunk_attn_mask[:1].expand(pad_n, -1)], dim=0 + ) + chunk_pos_ids = torch.cat( + [chunk_pos_ids, chunk_pos_ids[:1].expand(pad_n, -1)], dim=0 + ) + # Pad seq_ids with unused IDs + pad_seq = torch.arange( + batch_size, batch_size + pad_n, dtype=chunk_seq_ids.dtype + ) + chunk_seq_ids = torch.cat([chunk_seq_ids, pad_seq], dim=0) + chunk_sampling = torch.cat( + 
[chunk_sampling, chunk_sampling[:1].expand(pad_n, -1)], dim=0 + ) + if ( + chunk_prev_hidden is not None + and hasattr(chunk_prev_hidden, "ndim") + and chunk_prev_hidden.ndim > 0 + and chunk_prev_hidden.shape[0] > 0 + ): + chunk_prev_hidden = torch.cat( + [ + chunk_prev_hidden, + chunk_prev_hidden[:1].expand(pad_n, -1), + ], + dim=0, + ) + if ( + chunk_adapter_ids is not None + and hasattr(chunk_adapter_ids, "ndim") + and chunk_adapter_ids.ndim > 0 + and chunk_adapter_ids.shape[0] > 0 + ): + chunk_adapter_ids = torch.cat( + [ + chunk_adapter_ids, + chunk_adapter_ids[:1].expand(pad_n, -1), + ], + dim=0, + ) + if chunk_mrope.ndim == 3: + chunk_mrope = torch.cat( + [chunk_mrope, chunk_mrope[:, :1, :].expand(-1, pad_n, -1)], + dim=1, + ) + if chunk_vis_emb.ndim == 3: + chunk_vis_emb = torch.cat( + [ + chunk_vis_emb, + torch.zeros( + (pad_n,) + chunk_vis_emb.shape[1:], + dtype=chunk_vis_emb.dtype, + ), + ], + dim=0, + ) + chunk_vis_mask = torch.cat( + [ + chunk_vis_mask, + torch.full( + (pad_n,) + chunk_vis_mask.shape[1:], + fill_value=seq_len - 1, + dtype=chunk_vis_mask.dtype, + ), + ], + dim=0, + ) + + chunk_out = self.context_encoding_model( + chunk_input_ids, # 0 + chunk_attn_mask, # 1 + chunk_pos_ids, # 2 + chunk_seq_ids, # 3 + chunk_sampling, # 4 + chunk_prev_hidden, # 5 + chunk_adapter_ids, # 6 + *empties, # 7-20 + chunk_mrope, # 21 + chunk_vis_emb, # 22 + chunk_vis_mask, # 23 + ) + # Slice off padding from output logits + if actual_chunk < ctx_bs: + chunk_out = chunk_out[:actual_chunk] + output_logits.append(chunk_out) + + outputs = ( + torch.cat(output_logits, dim=0) + if len(output_logits) > 1 + else output_logits[0] + ) + self.kv_cache_populated = True + is_run_on_neuron = self.context_encoding_model.is_neuron() + else: + # ----------------------------------------------------------- + # TKG batch padding: ensure batch == compiled batch_size so + # ModelWrapper.forward() takes the fast path (_forward(*args)) + # instead of _forward_with_pad() which drops 
custom args. + # ----------------------------------------------------------- + tkg_bs = self.token_generation_model.neuron_config.batch_size + actual_bs = input_ids.shape[0] + + if actual_bs < tkg_bs: + pad_n = tkg_bs - actual_bs + + # Pad 2D tensors along dim 0 + input_ids = torch.cat( + [input_ids, input_ids[:1].expand(pad_n, -1)], dim=0 + ) + attention_mask = torch.cat( + [attention_mask, attention_mask[:1].expand(pad_n, -1)], dim=0 + ) + position_ids = torch.cat( + [position_ids, position_ids[:1].expand(pad_n, -1)], dim=0 + ) + # Pad seq_ids with valid in-range IDs (repeat last valid seq_id) + # Using out-of-range IDs causes OOB memory access in KV cache + pad_seq = seq_ids[-1:].expand(pad_n) + seq_ids = torch.cat([seq_ids, pad_seq], dim=0) + sampling_params = torch.cat( + [sampling_params, sampling_params[:1].expand(pad_n, -1)], dim=0 + ) + # Pad prev_hidden if it has batch dimension + if ( + prev_hidden is not None + and hasattr(prev_hidden, "ndim") + and prev_hidden.ndim > 0 + and prev_hidden.shape[0] > 0 + ): + prev_hidden = torch.cat( + [prev_hidden, prev_hidden[:1].expand(pad_n, -1)], dim=0 + ) + # Pad adapter_ids if it has batch dimension + if ( + adapter_ids is not None + and hasattr(adapter_ids, "ndim") + and adapter_ids.ndim > 0 + and adapter_ids.shape[0] > 0 + ): + adapter_ids = torch.cat( + [adapter_ids, adapter_ids[:1].expand(pad_n, -1)], dim=0 + ) + # mrope_position_ids for TKG is torch.zeros((0,)) -- no padding needed + # vision_embeddings for TKG is torch.zeros((0,)) -- no padding needed + # vision_mask for TKG is torch.zeros((0,)) -- no padding needed + + outputs = self.token_generation_model( + input_ids, # 0 + attention_mask, # 1 + position_ids, # 2 + seq_ids, # 3 + sampling_params, # 4 + prev_hidden, # 5 + adapter_ids, # 6 + *empties, # 7-20 + mrope_position_ids, # 21 + vision_embeddings, # 22 + vision_mask, # 23 + ) + + # Slice off padding from output + if actual_bs < tkg_bs: + outputs = outputs[:actual_bs] + + is_run_on_neuron = 
self.token_generation_model.is_neuron() + + return outputs, is_run_on_neuron + + def _get_model_outputs_prefix_caching( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + llava_args, + slot_mapping, + block_table, + full_context_lens, + computed_context_lens, + ): + """Handle prefix caching using block KV layout. + + Uses the base class prefix caching arg layout (positions 0-14) which aligns + with the model's forward() parameter positions, plus Qwen3.5 custom args + (mRoPE, vision) at positions 21-23. + + Trace layout: + 0: input_ids, 1: attention_mask, 2: position_ids, 3: seq_ids, + 4: sampling_params, 5: prev_hidden (empty), 6: adapter_ids, + 7-10: empties (medusa slots), + 11: slot_mapping, 12: active_block_table, 13: num_queries, + 14: computed_context_lens, 15-20: empties, + 21: mrope_position_ids, 22: vision_embeddings, 23: vision_mask + """ + batch_size = input_ids.shape[0] + seq_len = input_ids.shape[1] + is_prefill = self._is_prefill(position_ids) + + # Compute num_queries from full_context_lens and computed_context_lens + num_queries = full_context_lens - computed_context_lens + + # Determine which model to use (CTE or TKG) + is_context_encoding = input_ids.shape[-1] > 1 and not position_ids.min().item() + base_model = ( + self.context_encoding_model + if is_context_encoding + else self.token_generation_model + ) + + # Extract vision inputs from llava_args + if llava_args and len(llava_args) >= 2: + vision_embeddings = llava_args[0] + vision_mask = llava_args[1] + mrope_position_ids = llava_args[2] if len(llava_args) >= 3 else None + elif is_prefill: + vision_embeddings = torch.zeros( + (batch_size, seq_len, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + (batch_size, seq_len, 1), + fill_value=seq_len - 1, + dtype=torch.int32, + ) + mrope_position_ids = None + else: + vision_embeddings = torch.zeros((0,), dtype=torch.float32) + 
vision_mask = torch.zeros((0,), dtype=torch.int32) + mrope_position_ids = None + + # For CTE: generate mRoPE position IDs if not provided + if is_prefill: + if mrope_position_ids is None: + mrope_position_ids = ( + torch.arange(0, seq_len, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + else: + mrope_position_ids = torch.zeros((0,), dtype=torch.int32) + + # Build empties for unused positions + empties_7_10 = [torch.empty(0) for _ in range(4)] # positions 7-10 + empties_15_20 = [torch.empty(0) for _ in range(6)] # positions 15-20 + + # Call the model with the prefix caching layout + # Call the model with the prefix caching layout + outputs = base_model( + input_ids, # 0 + attention_mask, # 1 + position_ids, # 2 + seq_ids, # 3 + sampling_params, # 4 + torch.empty(0), # 5: prev_hidden (unused in prefix caching) + adapter_ids, # 6 + *empties_7_10, # 7-10: medusa slots (empty) + slot_mapping, # 11: slot_mapping + block_table, # 12: active_block_table + num_queries, # 13: num_queries + computed_context_lens, # 14: computed_context_lens + *empties_15_20, # 15-20: empty + mrope_position_ids, # 21: rotary_position_id + vision_embeddings, # 22: vision_embeddings + vision_mask, # 23: vision_mask + ) + + if is_context_encoding: + self.kv_cache_populated = True + + is_run_on_neuron = base_model.is_neuron() + return outputs, is_run_on_neuron + + def get_compiler_args(self): + if self.compile_tag == CONTEXT_ENCODING_MODEL_TAG: + optimization_level = "-O1" + else: + optimization_level = "-O1" + + compiler_args = ( + "--enable-saturate-infinity " + "--enable-mixed-precision-accumulation " + f"--model-type transformer {optimization_level} " + "--auto-cast=none " + "--internal-enable-dge-levels vector_dynamic_offsets " + ) + return compiler_args diff --git a/contrib/models/Qwen3-Coder-Next/src/nki_deltanet.py b/contrib/models/Qwen3-Coder-Next/src/nki_deltanet.py new file mode 100644 index 00000000..f7a17d3c --- /dev/null +++ 
"""NKI kernels for DeltaNet gated delta rule recurrent forward.

NKI v3 (SDK 2.30, NKI 0.4.0). Processes a SINGLE (batch, head) pair per kernel call.
The caller loops over (B, H) in PyTorch and calls this kernel for each pair.

Input layout: All inputs are 2D contiguous tensors (S, 128).
Each call processes one (batch, head) element's full sequence.

k_dim = v_dim = 128, which matches SBUF tile partition dimension exactly.
g and beta are scalars per token, expanded to (S, 128) by the caller.

Two kernel variants:
  deltanet_recurrent_fwd       -- returns output only (original)
  deltanet_recurrent_fwd_state -- returns (output, final_state) for CTE->TKG carry-over

Changes for NKI 0.4.0:
  - tensor_copy no longer supports implicit partition broadcast (P=1 -> P=128)
  - The kernels therefore copy PSUM -> SBUF as a (1, F) row and broadcast it
    across partitions with an nc_matmul against a (1, P) row of ones.
"""

import nki
import nki.isa as nisa
import nki.language as nl

# Partition dimension max (NeuronCore SBUF tile width)
P_MAX = 128


@nki.jit
def deltanet_recurrent_fwd(
    query: nl.ndarray,  # (S, 128) float32
    key: nl.ndarray,  # (S, 128) float32
    value: nl.ndarray,  # (S, 128) float32
    g_in: nl.ndarray,  # (S, 128) float32, log-decay broadcast to 128
    beta_in: nl.ndarray,  # (S, 128) float32, write gate broadcast to 128
) -> nl.ndarray:
    """NKI kernel for DeltaNet recurrent forward -- single (batch, head).

    Iterates over sequence tokens with sequential_range. The recurrent state
    is a (P_MAX, dim) float32 tile in SBUF, zero-initialised on every call.
    Per token t the kernel performs:

        state  = state * exp(g_t)
        kv_mem = state^T @ k_t
        delta  = (v_t - kv_mem) * beta_t
        state  = state + outer(k_t, delta)
        o_t    = state^T @ q_t

    Args:
        query:   (S, 128) float32
        key:     (S, 128) float32
        value:   (S, 128) float32
        g_in:    (S, 128) float32
        beta_in: (S, 128) float32

    Returns:
        output: (S, 128) tensor in shared HBM with query's dtype
    """
    seq_len, dim = query.shape

    # Output tensor in HBM
    output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm)

    # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1
    seq_stride = dim

    # Initialize recurrent state in SBUF: (128, 128)
    state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.memset(dst=state, value=0.0)

    # Sequential loop over tokens (each step depends on the previous state)
    for t in nl.sequential_range(seq_len):
        # Flat element offset of row t in the (S, dim) inputs/outputs
        tok_offset = t * seq_stride

        # ---- Load inputs for token t ----
        # Each row of P_MAX contiguous elements is loaded as a (P_MAX, 1)
        # column: one element per partition.
        q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=q_t,
            src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=k_t,
            src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=v_t,
            src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=g_t,
            src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=beta_t,
            src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        # ---- Step 1: Decay state -- state = state * exp(g_t) ----
        exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0)

        # exp_g is a (P_MAX, 1) per-partition scalar applied to each state row
        state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_scalar(
            dst=state_decayed,
            data=state,
            op0=nl.multiply,
            operand0=exp_g,
            engine=nisa.vector_engine,
        )
        nisa.tensor_copy(dst=state, src=state_decayed)

        # ---- Step 2: Read memory -- kv_mem = state^T @ k_t ----
        # Matmul results land in PSUM and are copied to SBUF before use.
        kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum)
        nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t)
        kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum)

        # ---- Step 3: delta = (v_t - kv_mem) * beta_t ----
        v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract)

        delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_scalar(
            dst=delta,
            data=v_sub,
            op0=nl.multiply,
            operand0=beta_t,
            engine=nisa.vector_engine,
        )

        # ---- Step 4: state += outer(k_t, delta) ----
        # outer[i,j] = k_t[i] * delta[j], built in three moves:
        #   1) Transpose delta (128,1) -> (1,128) in PSUM
        #   2) nc_matmul with a ones row to partition-broadcast (1,128) -> (128,128)
        #   3) Multiply by k_t (128,1), which broadcasts across the free dim

        # Transpose delta to get its values along the free dimension
        delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum)
        nisa.nc_transpose(dst=delta_row_psum, data=delta)

        # NKI 0.4.0: tensor_copy requires matching partition dims.
        # Copy (1,128) PSUM -> (1,128) SBUF first
        delta_row_sbuf = nl.ndarray((1, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=delta_row_sbuf, src=delta_row_psum)

        # Partition broadcast via nc_matmul.
        # nc_matmul(stationary=(P,K), moving=(P,F)) -> (K,F) in PSUM.
        # With P=1: stationary = ones (1, P_MAX), moving = delta_row (1, dim),
        # so the result is ones^T @ delta_row = (P_MAX, dim), i.e. delta_row
        # repeated on every partition.
        # NOTE(review): ones_col does not depend on t; it looks hoistable out
        # of the loop -- confirm with the NKI compiler before changing.
        ones_col = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
        nisa.memset(dst=ones_col, value=1.0)

        delta_broadcast_psum = nl.ndarray(
            (P_MAX, dim), dtype=nl.float32, buffer=nl.psum
        )
        nisa.nc_matmul(
            dst=delta_broadcast_psum, stationary=ones_col, moving=delta_row_sbuf
        )

        delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=delta_broadcast, src=delta_broadcast_psum)

        # Element-wise multiply: outer[i,j] = delta_broadcast[i,j] * k_t[i,0]
        # tensor_scalar broadcasts (P,1) k_t across all F columns
        outer_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_scalar(
            dst=outer_prod,
            data=delta_broadcast,
            op0=nl.multiply,
            operand0=k_t,
            engine=nisa.vector_engine,
        )

        # Accumulate into state
        state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add)
        nisa.tensor_copy(dst=state, src=state_new)

        # ---- Step 5: o_t = state^T @ q_t ----
        # Uses the state AFTER this token's update.
        o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum)
        nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t)
        o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=o_t, src=o_t_psum)

        # ---- Store output for token t ----
        nisa.dma_copy(
            dst=output.ap(pattern=[[1, dim]], offset=tok_offset),
            src=o_t,
        )

    return output


@nki.jit
def deltanet_recurrent_fwd_state(
    query: nl.ndarray,  # (S, 128) float32
    key: nl.ndarray,  # (S, 128) float32
    value: nl.ndarray,  # (S, 128) float32
    g_in: nl.ndarray,  # (S, 128) float32, log-decay broadcast to 128
    beta_in: nl.ndarray,  # (S, 128) float32, write gate broadcast to 128
):
    """NKI kernel for DeltaNet recurrent forward with final state output.

    Same recurrence as deltanet_recurrent_fwd, but ALSO writes the final
    recurrent state (128, 128) to an output HBM buffer. This enables
    CTE -> TKG state carry-over. The state is zero-initialised on every call;
    no initial state is accepted.

    Args:
        query:   (S, 128) float32
        key:     (S, 128) float32
        value:   (S, 128) float32
        g_in:    (S, 128) float32
        beta_in: (S, 128) float32

    Returns:
        output: (S, 128) -- per-token output, query's dtype
        final_state: (128, 128) float32 -- recurrent state after last token
    """
    seq_len, dim = query.shape

    # Output tensors in HBM
    output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm)
    final_state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm)

    # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1
    seq_stride = dim

    # Initialize recurrent state in SBUF: (128, 128)
    state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.memset(dst=state, value=0.0)

    # Sequential loop over tokens (each step depends on the previous state)
    for t in nl.sequential_range(seq_len):
        # Flat element offset of row t in the (S, dim) inputs/outputs
        tok_offset = t * seq_stride

        # ---- Load inputs for token t ----
        # Each row is loaded as a (P_MAX, 1) column: one element per partition.
        q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=q_t,
            src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=k_t,
            src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=v_t,
            src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=g_t,
            src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=beta_t,
            src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        # ---- Step 1: Decay state -- state = state * exp(g_t) ----
        exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0)

        state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_scalar(
            dst=state_decayed,
            data=state,
            op0=nl.multiply,
            operand0=exp_g,
            engine=nisa.vector_engine,
        )
        nisa.tensor_copy(dst=state, src=state_decayed)

        # ---- Step 2: Read memory -- kv_mem = state^T @ k_t ----
        kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum)
        nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t)
        kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum)

        # ---- Step 3: delta = (v_t - kv_mem) * beta_t ----
        v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract)

        delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_scalar(
            dst=delta,
            data=v_sub,
            op0=nl.multiply,
            operand0=beta_t,
            engine=nisa.vector_engine,
        )

        # ---- Step 4: state += outer(k_t, delta) ----
        # Broadcast multiply: outer[i,j] = k_t[i] * delta[j]
        delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum)
        nisa.nc_transpose(dst=delta_row_psum, data=delta)

        # NKI 0.4.0: Use nc_matmul with ones for partition broadcast
        # (PSUM row -> SBUF row, then ones^T @ row replicates it per partition)
        delta_row_sbuf = nl.ndarray((1, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=delta_row_sbuf, src=delta_row_psum)

        # NOTE(review): ones_col does not depend on t; it looks hoistable out
        # of the loop -- confirm with the NKI compiler before changing.
        ones_col = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
        nisa.memset(dst=ones_col, value=1.0)

        delta_broadcast_psum = nl.ndarray(
            (P_MAX, dim), dtype=nl.float32, buffer=nl.psum
        )
        nisa.nc_matmul(
            dst=delta_broadcast_psum, stationary=ones_col, moving=delta_row_sbuf
        )

        delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=delta_broadcast, src=delta_broadcast_psum)

        # outer[i,j] = delta_broadcast[i,j] * k_t[i,0]
        outer_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_scalar(
            dst=outer_prod,
            data=delta_broadcast,
            op0=nl.multiply,
            operand0=k_t,
            engine=nisa.vector_engine,
        )

        state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add)
        nisa.tensor_copy(dst=state, src=state_new)

        # ---- Step 5: o_t = state^T @ q_t ----
        # Uses the state AFTER this token's update.
        o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum)
        nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t)
        o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=o_t, src=o_t_psum)

        # ---- Store output for token t ----
        nisa.dma_copy(
            dst=output.ap(pattern=[[1, dim]], offset=tok_offset),
            src=o_t,
        )

    # ---- Write final state to HBM ----
    # state is (128, 128) in SBUF, copy to final_state in HBM
    # Use dma_copy with full tile: P_MAX rows, dim cols
    nisa.dma_copy(dst=final_state, src=state)

    return output, final_state
+ +Architecture follows xpu-perf v4 (nki_flash_attn_bf16_pipe_opt) adapted for: + - head_dim=256 (2x128 QK tiling, split PV output) + - GQA (multiple Q heads per KV head) + - Causal masking (affine_select with pattern/offset) + +Layouts (per-call, single batch + single kv_head): + Q: (1, q_h_per_k_h, seq_q, 256) -- BHSD (seq on partition, d on free in HBM) + K: (1, 1, seq_k, 256) -- BHSD (seq on partition, d on free in HBM) + V: (1, 1, seq_v, 256) -- BHSD (seq on partition, d on free) + O: (1, q_h_per_k_h, seq_q, 256) -- BHSD (seq on partition, d on free) + +Internal SBUF layout after DMA transpose of Q/K: + Q_sb: (D_TILE=128, Q_GRP_SZ=128) -- d on partition, seq on free + K_sb: (D_TILE=128, K_TILE_SZ=512) -- d on partition, seq on free + V_sb: (V_TILE_SZ=128, D_HEAD=256) -- seq on partition, d on free + +Pipeline stages (for grp_i in main loop): + Stage 1 (grp_i+2): load_q + qk_and_max -- DMA + TensorEngine (MM1) + Stage 2 (grp_i+1): exp + dma_transpose -- VectorEngine + DMA + Stage 3 (grp_i): pv + write_back -- TensorEngine (MM2) + DMA + +Run on trn2: + source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/bin/activate + python3 nki_flash_attn_d256_pipe.py # unit test + +Uses Beta 2 NKI API (`import nki`). Q/K accepted in BHSD layout; DMA transpose +during load converts to (D, S) SBUF layout needed for QK matmul. 
+""" + +import os + +os.environ.setdefault("NEURON_PLATFORM_TARGET_OVERRIDE", "trn2") + +import math +import numpy as np +import nki.isa as nisa +import nki.language as nl +import nki + +# ============================================================================ +# Constants +# ============================================================================ +D_HEAD = 256 +D_TILE = 128 # partition dim tile for d-tiling (256 = 2 x 128) +Q_GRP_SZ = 128 # Q group size = partition dim max +K_TILE_SZ = 512 # K tile size for MM1 (free dim of K in matmul) +V_TILE_SZ = 128 # V tile size for MM2 (partition dim of transposed P) +LARGE_TILE_SZ = 2048 # Large tile grouping +EXP_TILE_SZ = 512 # Exp tile for activation_reduce +PSUM_FMAX = 512 # PSUM free dimension max +FLOAT32_MIN = -3.4028235e38 + +# Partial RoPE constants (Qwen3.5: only 25% of head_dim=256 gets rotary) +ROPE_DIM = 64 # rope_dim = head_dim * partial_rotary_factor +ROPE_HALF = 32 # rope_dim // 2 — each half for rotate_half + + +# ============================================================================ +# ModularAllocator helpers (layout structure preserved, compiler-managed placement) +# ============================================================================ + + +def _align32(addr): + """Round up address to 32-byte alignment (required for DMA transpose).""" + return (addr + 31) // 32 * 32 + + +def _alloc_modular_1d(shape, dtype, block_dim, num_free_tiles, base_addr): + """Allocate 1D modular buffer list: block_dim entries, num_free_tiles physical. + + Elements at indices i and j share memory if i % num_free_tiles == j % num_free_tiles. + Returns (list_of_tensors, next_address). 
+ """ + base_addr = _align32(base_addr) + tile_elems = 1 + for d in shape[1:]: + tile_elems *= d + dtype_size = 4 if dtype == nl.float32 else 2 + tile_bytes = _align32(tile_elems * dtype_size) + + tensors = [] + for i in range(block_dim): + addr = base_addr + (i % num_free_tiles) * tile_bytes + tensors.append(nl.ndarray(shape, dtype=dtype, buffer=nl.sbuf)) + next_addr = base_addr + num_free_tiles * tile_bytes + return tensors, next_addr + + +def _alloc_modular_2d( + shape, dtype, block_dim0, block_dim1, num_free0, num_free1, base_addr +): + """Allocate 2D modular buffer: [block_dim0][block_dim1], with modular addressing. + + Returns (nested_list, next_address). + """ + base_addr = _align32(base_addr) + tile_elems = 1 + for d in shape[1:]: + tile_elems *= d + dtype_size = 4 if dtype == nl.float32 else 2 + tile_bytes = _align32(tile_elems * dtype_size) + + tensors = [] + for i in range(block_dim0): + row = [] + for j in range(block_dim1): + idx = (i % num_free0) * num_free1 + (j % num_free1) + addr = base_addr + idx * tile_bytes + row.append(nl.ndarray(shape, dtype=dtype, buffer=nl.sbuf)) + tensors.append(row) + next_addr = base_addr + num_free0 * num_free1 * tile_bytes + return tensors, next_addr + + +def _alloc_modular_3d(shape, dtype, dims, n_free, base_addr): + """Allocate 3D modular buffer: [d0][d1][d2], with modular addressing. + + dims = (block_dim0, block_dim1, block_dim2) + n_free = (num_free0, num_free1, num_free2) + Returns (nested_list, next_address). 
+ """ + base_addr = _align32(base_addr) + tile_elems = 1 + for d in shape[1:]: + tile_elems *= d + dtype_size = 4 if dtype == nl.float32 else 2 + tile_bytes = _align32(tile_elems * dtype_size) + + tensors = [] + for i in range(dims[0]): + layer = [] + for j in range(dims[1]): + row = [] + for k in range(dims[2]): + idx = ( + (i % n_free[0]) * n_free[1] * n_free[2] + + (j % n_free[1]) * n_free[2] + + (k % n_free[2]) + ) + addr = base_addr + idx * tile_bytes + row.append(nl.ndarray(shape, dtype=dtype, buffer=nl.sbuf)) + layer.append(row) + tensors.append(layer) + total_physical = n_free[0] * n_free[1] * n_free[2] + next_addr = base_addr + total_physical * tile_bytes + return tensors, next_addr + + +# ============================================================================ +# Pipeline stage functions +# ============================================================================ + + +def _pipe_load_q( + grp_i, + q_sb_lo, + q_sb_hi, + q_hbm, + d_tile, + seqlen_q, + batch_id, + q_head_idx, + n_heads, + d_head, + fuse_rope=False, + cos_lo_q_sb=None, + cos_hi_q_sb=None, + sin_q_sb=None, + rope_q_x1=None, + rope_q_x2=None, + rope_q_res1=None, + rope_q_res2=None, + cos_cache_hbm=None, + sin_cache_hbm=None, + rope_dim=64, +): + """Load Q group from BHSD HBM into SBUF with DMA transpose to (D, S) layout. + Optionally applies partial RoPE to the first rope_dim rows of q_sb_lo. + + Q_HBM layout: (1, H, S, D=256) -- BHSD + Q_SB layout: (D_TILE=128, Q_GRP_SZ=128) -- D on partition, S on free + DMA transpose: (S, D) in HBM -> (D, S) in SBUF + """ + q_start = grp_i * Q_GRP_SZ + num_q = min(seqlen_q - q_start, Q_GRP_SZ) + + # Compute flat offset into Q HBM: q[batch_id, q_head_idx, q_start, 0] + q_offset = ( + batch_id * n_heads * seqlen_q * d_head + + q_head_idx * seqlen_q * d_head + + q_start * d_head + ) + + # Lo half: D[0:128] + # Source HBM: BHSD layout, stride between S rows = d_head (not d_tile!) + # ap() format: [(stride, size), ...] 
— dim0 stride must be d_head=256 + # Dest SBUF: (D_TILE=128, Q_GRP_SZ=128) — transposed + nisa.dma_transpose( + dst=q_sb_lo[grp_i].ap([[Q_GRP_SZ, d_tile], [1, 1], [1, 1], [1, num_q]]), + src=q_hbm.ap( + [[d_head, num_q], [1, 1], [1, 1], [1, d_tile]], + offset=q_offset, + ), + ) + # Hi half: D[128:256] — same stride, offset shifted by d_tile + nisa.dma_transpose( + dst=q_sb_hi[grp_i].ap([[Q_GRP_SZ, d_tile], [1, 1], [1, 1], [1, num_q]]), + src=q_hbm.ap( + [[d_head, num_q], [1, 1], [1, 1], [1, d_tile]], + offset=q_offset + d_tile, + ), + ) + + # Apply partial RoPE to first rope_dim rows of q_sb_lo + if fuse_rope: + _load_rope_cos_sin_q( + grp_i, + cos_lo_q_sb, + cos_hi_q_sb, + sin_q_sb, + cos_cache_hbm, + sin_cache_hbm, + seqlen_q, + rope_dim, + ) + _apply_rope_q_sbuf( + grp_i, + q_sb_lo, + cos_lo_q_sb, + cos_hi_q_sb, + sin_q_sb, + rope_q_x1, + rope_q_x2, + rope_q_res1, + rope_q_res2, + seqlen_q, + ) + + +def _pipe_qk_and_max( + grp_i, + q_sb_lo, + q_sb_hi, + k_sb_lo, + k_sb_hi, + mm1_masked, + mm1_partial_max, + mm1_psum, + mm1_copy_sb, + mm1_asel_sb, + seqlen_q, + seqlen_kv, + scale, + num_k_tiles, + num_large_tiles, + use_causal_mask, +): + """Compute QK^T (MM1) with d=256 tiling, scale, causal mask, and row-wise max. + + Always applies causal masking via the nki-library pattern: + 1. tensor_copy PSUM -> mm1_copy_sb (temp SBUF) + 2. affine_select with pattern/offset -> mm1_asel_sb + 3. tensor_scalar_reduce: scale + max -> mm1_masked + mm1_partial_max + """ + q_start = grp_i * Q_GRP_SZ + num_q = min(seqlen_q - q_start, Q_GRP_SZ) + num_k_tiles_per_large = LARGE_TILE_SZ // K_TILE_SZ # 4 + + # Initialize partial max to -inf + nisa.memset(mm1_partial_max[grp_i][...], value=FLOAT32_MIN) + + # Initialize mm1_masked to -inf so that causally-skipped K tiles + # produce exp(-inf) = 0 in the softmax (no contribution). 
+ for lt_idx in range(num_large_tiles): + nisa.memset(mm1_masked[grp_i][lt_idx][...], value=FLOAT32_MIN) + + for large_tile_idx in range(num_large_tiles): + for k_tile_local in range(num_k_tiles_per_large): + k_tile_idx = large_tile_idx * num_k_tiles_per_large + k_tile_local + if k_tile_idx >= num_k_tiles: + continue + + k_start = k_tile_idx * K_TILE_SZ + num_k = min(seqlen_kv - k_start, K_TILE_SZ) + if num_k <= 0: + continue + + # Causal skip: entire Q group before this K tile + q_last = q_start + num_q - 1 + if q_last < k_start: + continue + + # MM1: QK = Q_lo^T @ K_lo + Q_hi^T @ K_hi + psum_tile = mm1_psum[grp_i][large_tile_idx][k_tile_local] + + # First half: d[0:128] + nisa.nc_matmul( + psum_tile[:num_q, :num_k], + q_sb_lo[grp_i][:D_TILE, :num_q], + k_sb_lo[k_tile_idx][:D_TILE, :num_k], + ) + # Second half: d[128:256] — accumulates into same PSUM + nisa.nc_matmul( + psum_tile[:num_q, :num_k], + q_sb_hi[grp_i][:D_TILE, :num_q], + k_sb_hi[k_tile_idx][:D_TILE, :num_k], + ) + + # Copy PSUM -> temp SBUF (unscaled) + nisa.tensor_copy( + mm1_copy_sb[:num_q, :num_k], + psum_tile[:num_q, :num_k], + ) + + # Causal mask via affine_select (nki-library pattern) + # val = (-1)*p + (1)*f + offset >= 0 means f <= p + offset + # For causal: keep when k_pos <= q_pos, i.e., (k_start+f) <= (q_start+p) + # i.e., f <= p + (q_start - k_start) + # So offset = q_start - k_start + nisa.affine_select( + dst=mm1_asel_sb[:num_q, :num_k], + pattern=[[-1, num_k]], + offset=q_start - k_start, + channel_multiplier=1, + cmp_op=nl.greater_equal, + on_true_tile=mm1_copy_sb[:num_q, :num_k], + on_false_value=FLOAT32_MIN, + ) + + # Scale + max extraction + nisa.tensor_scalar_reduce( + mm1_masked[grp_i][large_tile_idx][ + :num_q, nl.ds(k_tile_local * K_TILE_SZ, num_k) + ], + data=mm1_asel_sb[:num_q, :num_k], + op0=nl.multiply, + operand0=scale, + reduce_op=nl.maximum, + reduce_res=mm1_partial_max[grp_i][:num_q, k_tile_idx], + ) + + +def _pipe_update_max( + grp_i, mm1_partial_max, mm1_section_max, 
mm1_running_max, num_k_tiles, seqlen_q +): + """Compute section max from partial maxes, store as -max (negated).""" + q_start = grp_i * Q_GRP_SZ + num_q = min(seqlen_q - q_start, Q_GRP_SZ) + + # Section max with negate=True (stores -max for use as bias in exp) + nisa.tensor_reduce( + mm1_section_max[grp_i][:num_q, 0], + nl.maximum, + mm1_partial_max[grp_i][:num_q, :num_k_tiles], + 1, + negate=True, + ) + + # For single-section: running_max = section_max + nisa.tensor_copy(mm1_running_max[:num_q, grp_i], mm1_section_max[grp_i][:num_q, 0]) + + +def _pipe_exp( + grp_i, + mm1_masked, + mm1_running_max, + exp_sb, + exp_partial_sum, + exp_tp_sb, + seqlen_q, + seqlen_kv, + num_large_tiles, + num_k_tiles, +): + """Compute exp(S - max), partial sums, and DMA transpose for MM2.""" + q_start = grp_i * Q_GRP_SZ + num_q = min(seqlen_q - q_start, Q_GRP_SZ) + num_exp_per_large = LARGE_TILE_SZ // EXP_TILE_SZ # 4 + + nisa.memset(exp_partial_sum[grp_i][...], value=0.0) + + for large_tile_idx in range(num_large_tiles): + for exp_tile_idx in range(num_exp_per_large): + kv_start = large_tile_idx * LARGE_TILE_SZ + exp_tile_idx * EXP_TILE_SZ + num_kv = min(seqlen_kv - kv_start, EXP_TILE_SZ) + if num_kv <= 0: + continue + + # activation_reduce: dst = exp(1.0*data + bias), reduce_res = sum(dst) + # bias = mm1_running_max (which is -max) + nisa.activation_reduce( + exp_sb[grp_i][large_tile_idx][ + :num_q, nl.ds(exp_tile_idx * EXP_TILE_SZ, num_kv) + ], + op=nl.exp, + data=mm1_masked[grp_i][large_tile_idx][ + :num_q, nl.ds(exp_tile_idx * EXP_TILE_SZ, num_kv) + ], + reduce_op=nl.add, + reduce_res=exp_partial_sum[grp_i][ + :num_q, + large_tile_idx * num_exp_per_large + exp_tile_idx, + ], + bias=mm1_running_max[:num_q, grp_i], + ) + + # DMA transpose: exp_sb[Q=128, KV=512] -> exp_tp_sb[KV=128, Q=512] + num_kv_outer = num_kv // V_TILE_SZ + num_kv_inner = num_kv % V_TILE_SZ + + if num_kv_outer >= 1: + nisa.dma_transpose( + dst=exp_tp_sb[grp_i][large_tile_idx][exp_tile_idx].ap( + [ + [K_TILE_SZ, 
V_TILE_SZ], + [1, 1], + [V_TILE_SZ, num_kv_outer], + [1, num_q], + ] + ), + src=exp_sb[grp_i][large_tile_idx].ap( + [ + [LARGE_TILE_SZ, num_q], + [1, 1], + [V_TILE_SZ, num_kv_outer], + [1, V_TILE_SZ], + ], + offset=exp_tile_idx * K_TILE_SZ, + ), + ) + + if num_kv_inner > 0: + nisa.dma_transpose( + dst=exp_tp_sb[grp_i][large_tile_idx][exp_tile_idx].ap( + [ + [K_TILE_SZ, num_kv_inner], + [1, 1], + [V_TILE_SZ, 1], + [1, num_q], + ], + offset=num_kv_outer * V_TILE_SZ, + ), + src=exp_sb[grp_i][large_tile_idx].ap( + [ + [LARGE_TILE_SZ, num_q], + [1, 1], + [V_TILE_SZ, 1], + [1, num_kv_inner], + ], + offset=exp_tile_idx * K_TILE_SZ + num_kv_outer * V_TILE_SZ, + ), + ) + + +def _pipe_pv( + grp_i, + exp_tp_sb, + v_sb, + mm2_psum_lo, + mm2_psum_hi, + mm2_sb, + seqlen_q, + seqlen_kv, + num_large_tiles, + num_v_tiles, +): + """Compute P@V (MM2) with d=256 split into lo/hi halves.""" + q_start = grp_i * Q_GRP_SZ + num_q = min(seqlen_q - q_start, Q_GRP_SZ) + num_mm2_grps_per_large = LARGE_TILE_SZ // K_TILE_SZ # 4 + num_mm2_per_grp = K_TILE_SZ // V_TILE_SZ # 4 + + # Zero the output accumulator + nisa.memset(mm2_sb[grp_i][...], value=0.0) + + for large_tile_idx in range(num_large_tiles): + psum_tile_lo = mm2_psum_lo[grp_i][large_tile_idx] + psum_tile_hi = mm2_psum_hi[grp_i][large_tile_idx] + + for mm2_grp_i in range(num_mm2_grps_per_large): + exp_tp_tile = exp_tp_sb[grp_i][large_tile_idx][mm2_grp_i] + + for mm2_i in range(num_mm2_per_grp): + v_tile_idx = ( + large_tile_idx * num_mm2_grps_per_large * num_mm2_per_grp + + mm2_grp_i * num_mm2_per_grp + + mm2_i + ) + kv_start = v_tile_idx * V_TILE_SZ + num_kv = min(seqlen_kv - kv_start, V_TILE_SZ) + if num_kv <= 0 or v_tile_idx >= num_v_tiles: + continue + + # MM2 lo: exp_tp^T @ V[:, :128] = [Q_GRP, 128] + nisa.nc_matmul( + psum_tile_lo[:num_q, :D_TILE], + exp_tp_tile[:num_kv, nl.ds(mm2_i * V_TILE_SZ, num_q)], + v_sb[v_tile_idx][:num_kv, :D_TILE], + ) + # MM2 hi: exp_tp^T @ V[:, 128:256] = [Q_GRP, 128] + nisa.nc_matmul( + 
psum_tile_hi[:num_q, :D_TILE], + exp_tp_tile[:num_kv, nl.ds(mm2_i * V_TILE_SZ, num_q)], + v_sb[v_tile_idx][:num_kv, nl.ds(D_TILE, D_TILE)], + ) + + # Accumulate large tile results into SBUF + if large_tile_idx == 0: + nisa.tensor_copy( + mm2_sb[grp_i][:num_q, :D_TILE], + psum_tile_lo[:num_q, :D_TILE], + ) + nisa.tensor_copy( + mm2_sb[grp_i][:num_q, nl.ds(D_TILE, D_TILE)], + psum_tile_hi[:num_q, :D_TILE], + ) + else: + nisa.tensor_tensor( + mm2_sb[grp_i][:num_q, :D_TILE], + mm2_sb[grp_i][:num_q, :D_TILE], + psum_tile_lo[:num_q, :D_TILE], + nl.add, + ) + nisa.tensor_tensor( + mm2_sb[grp_i][:num_q, nl.ds(D_TILE, D_TILE)], + mm2_sb[grp_i][:num_q, nl.ds(D_TILE, D_TILE)], + psum_tile_hi[:num_q, :D_TILE], + nl.add, + ) + + +def _pipe_fused_qkmax_and_pv( + grp_i, + q_sb_lo, + q_sb_hi, + k_sb_lo, + k_sb_hi, + mm1_masked, + mm1_partial_max, + mm1_psum, + mm1_copy_sb, + mm1_asel_sb, + exp_tp_sb, + v_sb, + mm2_psum_lo, + mm2_psum_hi, + mm2_sb, + seqlen_q, + seqlen_kv, + scale, + num_k_tiles, + num_large_tiles, + num_v_tiles, + use_causal_mask, +): + """Fused: QK+max for grp_i+2, PV for grp_i (interleaved MM1+MM2).""" + qkmax_grp = grp_i + 2 + pv_grp = grp_i + + q_start_pv = pv_grp * Q_GRP_SZ + num_q_pv = min(seqlen_q - q_start_pv, Q_GRP_SZ) + q_start_qk = qkmax_grp * Q_GRP_SZ + num_q_qk = min(seqlen_q - q_start_qk, Q_GRP_SZ) + + num_k_tiles_per_large = LARGE_TILE_SZ // K_TILE_SZ # 4 + num_mm2_grps_per_large = LARGE_TILE_SZ // K_TILE_SZ # 4 + num_mm2_per_grp = K_TILE_SZ // V_TILE_SZ # 4 + + # Init partial max for QK grp + nisa.memset(mm1_partial_max[qkmax_grp][...], value=FLOAT32_MIN) + # Init mm1_masked for QK grp to -inf (causally-skipped K tiles → exp=0) + for lt_idx in range(num_large_tiles): + nisa.memset(mm1_masked[qkmax_grp][lt_idx][...], value=FLOAT32_MIN) + # Init MM2 accumulator for PV grp + nisa.memset(mm2_sb[pv_grp][...], value=0.0) + + for large_tile_idx in range(num_large_tiles): + # --- PV for pv_grp --- + psum_tile_pv_lo = mm2_psum_lo[pv_grp][large_tile_idx] + 
psum_tile_pv_hi = mm2_psum_hi[pv_grp][large_tile_idx] + + for mm2_grp_i in range(num_mm2_grps_per_large): + exp_tp_tile = exp_tp_sb[pv_grp][large_tile_idx][mm2_grp_i] + + for mm2_i in range(num_mm2_per_grp): + v_tile_idx = ( + large_tile_idx * num_mm2_grps_per_large * num_mm2_per_grp + + mm2_grp_i * num_mm2_per_grp + + mm2_i + ) + kv_start = v_tile_idx * V_TILE_SZ + num_kv = min(seqlen_kv - kv_start, V_TILE_SZ) + if num_kv <= 0 or v_tile_idx >= num_v_tiles: + continue + + nisa.nc_matmul( + psum_tile_pv_lo[:num_q_pv, :D_TILE], + exp_tp_tile[:num_kv, nl.ds(mm2_i * V_TILE_SZ, num_q_pv)], + v_sb[v_tile_idx][:num_kv, :D_TILE], + ) + nisa.nc_matmul( + psum_tile_pv_hi[:num_q_pv, :D_TILE], + exp_tp_tile[:num_kv, nl.ds(mm2_i * V_TILE_SZ, num_q_pv)], + v_sb[v_tile_idx][:num_kv, nl.ds(D_TILE, D_TILE)], + ) + + # Accumulate PV large tile + if large_tile_idx == 0: + nisa.tensor_copy( + mm2_sb[pv_grp][:num_q_pv, :D_TILE], + psum_tile_pv_lo[:num_q_pv, :D_TILE], + ) + nisa.tensor_copy( + mm2_sb[pv_grp][:num_q_pv, nl.ds(D_TILE, D_TILE)], + psum_tile_pv_hi[:num_q_pv, :D_TILE], + ) + else: + nisa.tensor_tensor( + mm2_sb[pv_grp][:num_q_pv, :D_TILE], + mm2_sb[pv_grp][:num_q_pv, :D_TILE], + psum_tile_pv_lo[:num_q_pv, :D_TILE], + nl.add, + ) + nisa.tensor_tensor( + mm2_sb[pv_grp][:num_q_pv, nl.ds(D_TILE, D_TILE)], + mm2_sb[pv_grp][:num_q_pv, nl.ds(D_TILE, D_TILE)], + psum_tile_pv_hi[:num_q_pv, :D_TILE], + nl.add, + ) + + # --- QK+max for qkmax_grp --- + for k_tile_local in range(num_k_tiles_per_large): + k_tile_idx = large_tile_idx * num_k_tiles_per_large + k_tile_local + if k_tile_idx >= num_k_tiles: + continue + + k_start = k_tile_idx * K_TILE_SZ + num_k = min(seqlen_kv - k_start, K_TILE_SZ) + if num_k <= 0: + continue + + q_last = q_start_qk + num_q_qk - 1 + if q_last < k_start: + continue + + psum_tile_qk = mm1_psum[qkmax_grp][large_tile_idx][k_tile_local] + + # d=256 tiled QK: two nc_matmul calls + nisa.nc_matmul( + psum_tile_qk[:num_q_qk, :num_k], + q_sb_lo[qkmax_grp][:D_TILE, 
:num_q_qk], + k_sb_lo[k_tile_idx][:D_TILE, :num_k], + ) + nisa.nc_matmul( + psum_tile_qk[:num_q_qk, :num_k], + q_sb_hi[qkmax_grp][:D_TILE, :num_q_qk], + k_sb_hi[k_tile_idx][:D_TILE, :num_k], + ) + + nisa.tensor_copy( + mm1_copy_sb[:num_q_qk, :num_k], + psum_tile_qk[:num_q_qk, :num_k], + ) + + nisa.affine_select( + dst=mm1_asel_sb[:num_q_qk, :num_k], + pattern=[[-1, num_k]], + offset=q_start_qk - k_start, + channel_multiplier=1, + cmp_op=nl.greater_equal, + on_true_tile=mm1_copy_sb[:num_q_qk, :num_k], + on_false_value=FLOAT32_MIN, + ) + + nisa.tensor_scalar_reduce( + mm1_masked[qkmax_grp][large_tile_idx][ + :num_q_qk, nl.ds(k_tile_local * K_TILE_SZ, num_k) + ], + data=mm1_asel_sb[:num_q_qk, :num_k], + op0=nl.multiply, + operand0=scale, + reduce_op=nl.maximum, + reduce_res=mm1_partial_max[qkmax_grp][:num_q_qk, k_tile_idx], + ) + + +def _pipe_write_back( + grp_i, + mm2_sb, + exp_partial_sum, + exp_sum_recip, + wb_exp_section_sum, + wb_zero_bias, + wb_o_bf16, + o_hbm, + seqlen_q, + num_exp_tiles, + batch_id, + q_head_idx, +): + """Write-back: normalize by 1/sum(exp), cast to bf16, DMA to HBM.""" + q_start = grp_i * Q_GRP_SZ + num_q = min(seqlen_q - q_start, Q_GRP_SZ) + + # Reduce partial exp sums to get total + nisa.tensor_reduce( + wb_exp_section_sum[grp_i][:num_q, 0], + nl.add, + exp_partial_sum[grp_i][:num_q, :num_exp_tiles], + axis=1, + ) + + # Reciprocal + nisa.reciprocal( + exp_sum_recip[grp_i][:num_q, 0], + wb_exp_section_sum[grp_i][:num_q, 0], + ) + + # Scale output and cast to bf16: copy(recip * mm2_sb + zero_bias) -> bf16 + nisa.activation( + wb_o_bf16[grp_i][:num_q, :D_HEAD], + nl.copy, + mm2_sb[grp_i][:num_q, :D_HEAD], + scale=exp_sum_recip[grp_i][:num_q, 0], + bias=wb_zero_bias[:num_q], + ) + + # DMA to HBM output + nisa.dma_copy( + dst=o_hbm[batch_id, q_head_idx, q_start : q_start + num_q, 0:D_HEAD], + src=wb_o_bf16[grp_i][:num_q, :D_HEAD], + ) + + +# ============================================================================ +# RoPE helper functions 
(partial RoPE fusion) +# ============================================================================ + + +def _apply_rope_q_sbuf( + grp_i, + q_sb_lo, + cos_lo_sb, + cos_hi_sb, + sin_sb, + rope_x1, + rope_x2, + rope_res1, + rope_res2, + seqlen_q, +): + """Apply partial RoPE to Q group in SBUF (transposed D,S layout). + + Only the first ROPE_DIM=64 partition rows of q_sb_lo get rotated. + Rows 64:128 (and all of q_sb_hi) are unchanged. + + In the (D, S) layout: + q_sb_lo[0:32, :] = X1 (first half of rope dims) + q_sb_lo[32:64, :] = X2 (second half of rope dims) + + RoPE formula: + result[0:32] = X1 * cos_lo - X2 * sin + result[32:64] = X2 * cos_hi + X1 * sin + + All tensor_tensor ops use operands at partition row 0 to satisfy + NCC_IBIR297 (both SB inputs must have same partition base). + Results are computed entirely in the rope workspace buffers, then + copied to q_sb_lo. + + Buffers (all ROPE_HALF=32 partition rows, partition start 0): + cos_lo_sb: (32, S) — cos[0:32] + cos_hi_sb: (32, S) — cos[32:64] + sin_sb: (32, S) — sin[0:32] (symmetric) + rope_x1: (32, S) — X1 copy + rope_x2: (32, S) — X2 copy + rope_res1: (32, S) — intermediate/result + rope_res2: (32, S) — intermediate + """ + q_start = grp_i * Q_GRP_SZ + num_q = min(seqlen_q - q_start, Q_GRP_SZ) + + # Save original X1 and X2 into partition-0 aligned buffers + nisa.tensor_copy( + dst=rope_x1[:ROPE_HALF, :num_q], + src=q_sb_lo[grp_i][:ROPE_HALF, :num_q], + ) + nisa.tensor_copy( + dst=rope_x2[:ROPE_HALF, :num_q], + src=q_sb_lo[grp_i][nl.ds(ROPE_HALF, ROPE_HALF), :num_q], + ) + + # --- First half: q[0:32] = X1 * cos_lo - X2 * sin --- + nisa.tensor_tensor( + dst=rope_res1[:ROPE_HALF, :num_q], + data1=rope_x1[:ROPE_HALF, :num_q], + data2=cos_lo_sb[:ROPE_HALF, :num_q], + op=nl.multiply, + ) + nisa.tensor_tensor( + dst=rope_res2[:ROPE_HALF, :num_q], + data1=rope_x2[:ROPE_HALF, :num_q], + data2=sin_sb[:ROPE_HALF, :num_q], + op=nl.multiply, + ) + nisa.tensor_tensor( + dst=rope_res1[:ROPE_HALF, :num_q], + 
data1=rope_res1[:ROPE_HALF, :num_q], + data2=rope_res2[:ROPE_HALF, :num_q], + op=nl.subtract, + ) + nisa.tensor_copy( + dst=q_sb_lo[grp_i][:ROPE_HALF, :num_q], + src=rope_res1[:ROPE_HALF, :num_q], + ) + + # --- Second half: q[32:64] = X2 * cos_hi + X1 * sin --- + nisa.tensor_tensor( + dst=rope_res1[:ROPE_HALF, :num_q], + data1=rope_x2[:ROPE_HALF, :num_q], + data2=cos_hi_sb[:ROPE_HALF, :num_q], + op=nl.multiply, + ) + nisa.tensor_tensor( + dst=rope_res2[:ROPE_HALF, :num_q], + data1=rope_x1[:ROPE_HALF, :num_q], + data2=sin_sb[:ROPE_HALF, :num_q], + op=nl.multiply, + ) + nisa.tensor_tensor( + dst=rope_res1[:ROPE_HALF, :num_q], + data1=rope_res1[:ROPE_HALF, :num_q], + data2=rope_res2[:ROPE_HALF, :num_q], + op=nl.add, + ) + nisa.tensor_copy( + dst=q_sb_lo[grp_i][nl.ds(ROPE_HALF, ROPE_HALF), :num_q], + src=rope_res1[:ROPE_HALF, :num_q], + ) + + +def _apply_rope_k_sbuf( + k_tile_idx, + k_sb_lo, + cos_lo_sb, + cos_hi_sb, + sin_sb, + rope_x1, + rope_x2, + rope_res1, + rope_res2, + seqlen_kv, +): + """Apply partial RoPE to a K tile in SBUF (transposed D,S layout). + Same algorithm as Q, adapted for K_TILE_SZ free dim. 
+ """ + k_start = k_tile_idx * K_TILE_SZ + num_k = min(seqlen_kv - k_start, K_TILE_SZ) + + nisa.tensor_copy( + dst=rope_x1[:ROPE_HALF, :num_k], + src=k_sb_lo[k_tile_idx][:ROPE_HALF, :num_k], + ) + nisa.tensor_copy( + dst=rope_x2[:ROPE_HALF, :num_k], + src=k_sb_lo[k_tile_idx][nl.ds(ROPE_HALF, ROPE_HALF), :num_k], + ) + + # First half: k[0:32] = X1 * cos_lo - X2 * sin + nisa.tensor_tensor( + dst=rope_res1[:ROPE_HALF, :num_k], + data1=rope_x1[:ROPE_HALF, :num_k], + data2=cos_lo_sb[:ROPE_HALF, :num_k], + op=nl.multiply, + ) + nisa.tensor_tensor( + dst=rope_res2[:ROPE_HALF, :num_k], + data1=rope_x2[:ROPE_HALF, :num_k], + data2=sin_sb[:ROPE_HALF, :num_k], + op=nl.multiply, + ) + nisa.tensor_tensor( + dst=rope_res1[:ROPE_HALF, :num_k], + data1=rope_res1[:ROPE_HALF, :num_k], + data2=rope_res2[:ROPE_HALF, :num_k], + op=nl.subtract, + ) + nisa.tensor_copy( + dst=k_sb_lo[k_tile_idx][:ROPE_HALF, :num_k], + src=rope_res1[:ROPE_HALF, :num_k], + ) + + # Second half: k[32:64] = X2 * cos_hi + X1 * sin + nisa.tensor_tensor( + dst=rope_res1[:ROPE_HALF, :num_k], + data1=rope_x2[:ROPE_HALF, :num_k], + data2=cos_hi_sb[:ROPE_HALF, :num_k], + op=nl.multiply, + ) + nisa.tensor_tensor( + dst=rope_res2[:ROPE_HALF, :num_k], + data1=rope_x1[:ROPE_HALF, :num_k], + data2=sin_sb[:ROPE_HALF, :num_k], + op=nl.multiply, + ) + nisa.tensor_tensor( + dst=rope_res1[:ROPE_HALF, :num_k], + data1=rope_res1[:ROPE_HALF, :num_k], + data2=rope_res2[:ROPE_HALF, :num_k], + op=nl.add, + ) + nisa.tensor_copy( + dst=k_sb_lo[k_tile_idx][nl.ds(ROPE_HALF, ROPE_HALF), :num_k], + src=rope_res1[:ROPE_HALF, :num_k], + ) + + +def _load_rope_cos_sin_q( + grp_i, + cos_lo_sb, + cos_hi_sb, + sin_sb, + cos_cache_hbm, + sin_cache_hbm, + seqlen_q, + rope_dim, +): + """Load cos/sin for a Q group from HBM into SBUF with DMA transpose. 
+ + HBM layout: cos_cache (S, rope_dim=64) — per batch (batch squeezed) + SBUF layout: split into lo/hi (ROPE_HALF=32, Q_GRP_SZ) each + + cos_lo_sb: rows 0:32 of cos (D[0:32]) + cos_hi_sb: rows 32:64 of cos (D[32:64]) + sin_sb: rows 0:32 of sin (only first half — symmetric) + """ + q_start = grp_i * Q_GRP_SZ + num_q = min(seqlen_q - q_start, Q_GRP_SZ) + rope_half = rope_dim // 2 + + # cos lo: DMA transpose first rope_half cols of cos → (32, S) + cos_offset = q_start * rope_dim + nisa.dma_transpose( + dst=cos_lo_sb.ap([[Q_GRP_SZ, rope_half], [1, 1], [1, 1], [1, num_q]]), + src=cos_cache_hbm.ap( + [[rope_dim, num_q], [1, 1], [1, 1], [1, rope_half]], + offset=cos_offset, + ), + ) + + # cos hi: DMA transpose second rope_half cols of cos → (32, S) + nisa.dma_transpose( + dst=cos_hi_sb.ap([[Q_GRP_SZ, rope_half], [1, 1], [1, 1], [1, num_q]]), + src=cos_cache_hbm.ap( + [[rope_dim, num_q], [1, 1], [1, 1], [1, rope_half]], + offset=cos_offset + rope_half, + ), + ) + + # sin: DMA transpose first rope_half cols → (32, S) + sin_offset = q_start * rope_dim + nisa.dma_transpose( + dst=sin_sb.ap([[Q_GRP_SZ, rope_half], [1, 1], [1, 1], [1, num_q]]), + src=sin_cache_hbm.ap( + [[rope_dim, num_q], [1, 1], [1, 1], [1, rope_half]], + offset=sin_offset, + ), + ) + + +def _load_rope_cos_sin_k( + k_tile_idx, + cos_lo_sb, + cos_hi_sb, + sin_sb, + cos_cache_hbm, + sin_cache_hbm, + seqlen_kv, + rope_dim, +): + """Load cos/sin for a K tile from HBM into SBUF with DMA transpose. + Same as Q version but for K_TILE_SZ=512 free dim. 
+ """ + k_start = k_tile_idx * K_TILE_SZ + num_k = min(seqlen_kv - k_start, K_TILE_SZ) + rope_half = rope_dim // 2 + + cos_offset = k_start * rope_dim + nisa.dma_transpose( + dst=cos_lo_sb.ap([[K_TILE_SZ, rope_half], [1, 1], [1, 1], [1, num_k]]), + src=cos_cache_hbm.ap( + [[rope_dim, num_k], [1, 1], [1, 1], [1, rope_half]], + offset=cos_offset, + ), + ) + nisa.dma_transpose( + dst=cos_hi_sb.ap([[K_TILE_SZ, rope_half], [1, 1], [1, 1], [1, num_k]]), + src=cos_cache_hbm.ap( + [[rope_dim, num_k], [1, 1], [1, 1], [1, rope_half]], + offset=cos_offset + rope_half, + ), + ) + + sin_offset = k_start * rope_dim + nisa.dma_transpose( + dst=sin_sb.ap([[K_TILE_SZ, rope_half], [1, 1], [1, 1], [1, num_k]]), + src=sin_cache_hbm.ap( + [[rope_dim, num_k], [1, 1], [1, 1], [1, rope_half]], + offset=sin_offset, + ), + ) + + +# ============================================================================ +# Main kernel +# ============================================================================ + + +@nki.jit +def flash_attn_d256_pipe( + q, + k, + v, + cos_cache=None, + sin_cache=None, + use_causal_mask=True, + q_h_per_k_h=4, + n_kv_heads=1, + seqlen_q=512, + seqlen_kv=512, + rope_dim=64, +): + """ + Flash attention for head_dim=256, 3-stage software pipelined, with fused partial RoPE. + + Called per (batch, kv_head) pair with pre-sliced tensors (like deltanet). + The caller loops over (B, kv_heads) and passes single-element slices. + + Q and K are accepted in BHSD layout (standard PyTorch layout), and the + kernel transposes them internally via DMA during load. This avoids the + need for torch.permute() in the caller, which would create XLA tensors + incompatible with the Beta 2 tracer in SPMD context. + + When cos_cache and sin_cache are provided, the kernel applies partial RoPE + (first rope_dim=64 of 256 head dims) internally to Q and K after loading + them into SBUF. 
This bypasses the Beta 2 tracer None-args issue where + element-wise ops with model buffers (cos/sin caches) cause Q/K to resolve + as None during KLIR tracing. By passing cos/sin as separate HBM inputs + (which are NOT derived from element-wise ops), we avoid the issue entirely. + + Args: + q: (1, q_h_per_k_h, seq_q, 256) -- bfloat16, BHSD, PRE-RoPE Q heads for one KV head + k: (1, 1, seq_k, 256) -- bfloat16, BHSD, PRE-RoPE single KV head + v: (1, 1, seq_v, 256) -- bfloat16, BHSD, single KV head (no RoPE) + cos_cache: (seq, rope_dim=64) -- bfloat16, cos values (batch dim squeezed) + sin_cache: (seq, rope_dim=64) -- bfloat16, sin values (batch dim squeezed) + use_causal_mask: bool + q_h_per_k_h: Q heads per KV head (explicit, avoids .shape) + n_kv_heads: must be 1 (kernel processes one KV head at a time) + seqlen_q: sequence length for Q + seqlen_kv: sequence length for K/V + rope_dim: number of head dims that get rotary embedding (default 64) + + Returns: + o: (1, q_h_per_k_h, seq_q, 256) -- bfloat16, BHSD (post-RoPE attention output) + """ + d = D_HEAD + n_heads = q_h_per_k_h * n_kv_heads + bs = 1 + + scale = 1.0 / math.sqrt(d) + + # Fixed indices — caller pre-slices tensors per (batch, kv_head) + batch_id = 0 + kv_head_id = 0 + + # Output allocation + o = nl.ndarray((1, n_heads, seqlen_q, d), dtype=nl.bfloat16, buffer=nl.shared_hbm) + + num_grps = (seqlen_q + Q_GRP_SZ - 1) // Q_GRP_SZ + num_k_tiles = (seqlen_kv + K_TILE_SZ - 1) // K_TILE_SZ + num_v_tiles = (seqlen_kv + V_TILE_SZ - 1) // V_TILE_SZ + num_large_tiles = (seqlen_kv + LARGE_TILE_SZ - 1) // LARGE_TILE_SZ + num_exp_per_large = LARGE_TILE_SZ // EXP_TILE_SZ # 4 + num_exp_tiles = num_large_tiles * num_exp_per_large + + # ========================================================================= + # Buffer Allocation (ModularAllocator-style) + # ========================================================================= + sca = 0 # SBUF current address + + # K lo buffers: [128, K_TILE_SZ=512] x num_k_tiles 
(all loaded) + k_sb_lo, sca = _alloc_modular_1d( + (D_TILE, K_TILE_SZ), + nl.bfloat16, + block_dim=num_k_tiles, + num_free_tiles=num_k_tiles, + base_addr=sca, + ) + # K hi buffers: [128, K_TILE_SZ=512] x num_k_tiles (all loaded) + k_sb_hi, sca = _alloc_modular_1d( + (D_TILE, K_TILE_SZ), + nl.bfloat16, + block_dim=num_k_tiles, + num_free_tiles=num_k_tiles, + base_addr=sca, + ) + + # V buffers: [V_TILE_SZ=128, D_HEAD=256] x num_v_tiles (all loaded) + v_sb, sca = _alloc_modular_1d( + (V_TILE_SZ, D_HEAD), + nl.bfloat16, + block_dim=num_v_tiles, + num_free_tiles=num_v_tiles, + base_addr=sca, + ) + + # Q lo buffers: [128, Q_GRP_SZ=128] x num_grps (modular 2) + q_sb_lo, sca = _alloc_modular_1d( + (D_TILE, Q_GRP_SZ), + nl.bfloat16, + block_dim=num_grps, + num_free_tiles=2, + base_addr=sca, + ) + # Q hi buffers: [128, Q_GRP_SZ=128] x num_grps (modular 2) + q_sb_hi, sca = _alloc_modular_1d( + (D_TILE, Q_GRP_SZ), + nl.bfloat16, + block_dim=num_grps, + num_free_tiles=2, + base_addr=sca, + ) + + # Causal masking temp buffers (shared, reused per K tile) + # mm1_copy_sb: [Q_GRP_SZ, K_TILE_SZ=512] -- PSUM copy for masking + sca = _align32(sca) + mm1_copy_sb = nl.ndarray( + (Q_GRP_SZ, K_TILE_SZ), + dtype=nl.float32, + buffer=nl.sbuf, + ) + sca += K_TILE_SZ * 4 # f32 + + # mm1_asel_sb: [Q_GRP_SZ, K_TILE_SZ=512] -- affine_select output + sca = _align32(sca) + mm1_asel_sb = nl.ndarray( + (Q_GRP_SZ, K_TILE_SZ), + dtype=nl.float32, + buffer=nl.sbuf, + ) + sca += K_TILE_SZ * 4 # f32 + + # mm1_masked: [Q_GRP_SZ, LARGE_TILE_SZ=2048] x [num_grps, num_large_tiles] + mm1_masked, sca = _alloc_modular_2d( + (Q_GRP_SZ, LARGE_TILE_SZ), + nl.float32, + num_grps, + num_large_tiles, + 2, + num_large_tiles, + sca, + ) + + # mm1_partial_max: [Q_GRP_SZ, num_k_tiles] x num_grps (modular 2) + mm1_partial_max, sca = _alloc_modular_1d( + (Q_GRP_SZ, num_k_tiles), + nl.float32, + block_dim=num_grps, + num_free_tiles=2, + base_addr=sca, + ) + + # mm1_section_max: [Q_GRP_SZ, 1] x num_grps (modular 2) + 
mm1_section_max, sca = _alloc_modular_1d( + (Q_GRP_SZ, 1), + nl.float32, + block_dim=num_grps, + num_free_tiles=2, + base_addr=sca, + ) + + # mm1_running_max: [Q_GRP_SZ, num_grps] -- persistent + sca = _align32(sca) + mm1_running_max = nl.ndarray( + (Q_GRP_SZ, num_grps), + dtype=nl.float32, + buffer=nl.sbuf, + ) + sca += num_grps * 4 + + # exp_sb: [Q_GRP_SZ, LARGE_TILE_SZ] x [num_grps, num_large_tiles] + exp_sb, sca = _alloc_modular_2d( + (Q_GRP_SZ, LARGE_TILE_SZ), + nl.bfloat16, + num_grps, + num_large_tiles, + 1, + num_large_tiles, + sca, + ) + + # exp_partial_sum: [Q_GRP_SZ, num_exp_tiles] x num_grps (modular 2) + exp_partial_sum, sca = _alloc_modular_1d( + (Q_GRP_SZ, num_exp_tiles), + nl.float32, + block_dim=num_grps, + num_free_tiles=2, + base_addr=sca, + ) + + # exp_tp_sb: [V_TILE_SZ=128, K_TILE_SZ=512] x [grps, large, exp_per_large] + exp_tp_sb, sca = _alloc_modular_3d( + (V_TILE_SZ, K_TILE_SZ), + nl.bfloat16, + (num_grps, num_large_tiles, num_exp_per_large), + (2, num_large_tiles, num_exp_per_large), + sca, + ) + + # mm2_sb: [Q_GRP_SZ, D_HEAD=256] x num_grps (modular 2) + mm2_sb, sca = _alloc_modular_1d( + (Q_GRP_SZ, D_HEAD), + nl.float32, + block_dim=num_grps, + num_free_tiles=2, + base_addr=sca, + ) + + # exp_sum_recip: [Q_GRP_SZ, 1] x num_grps (modular 2) + exp_sum_recip, sca = _alloc_modular_1d( + (Q_GRP_SZ, 1), + nl.float32, + block_dim=num_grps, + num_free_tiles=2, + base_addr=sca, + ) + + # Write-back buffers + wb_exp_section_sum, sca = _alloc_modular_1d( + (Q_GRP_SZ, 1), + nl.float32, + block_dim=num_grps, + num_free_tiles=2, + base_addr=sca, + ) + sca = _align32(sca) + wb_zero_bias = nl.ndarray( + (Q_GRP_SZ, 1), + dtype=nl.float32, + buffer=nl.sbuf, + ) + sca += 1 * 4 + wb_o_bf16, sca = _alloc_modular_1d( + (Q_GRP_SZ, D_HEAD), + nl.bfloat16, + block_dim=num_grps, + num_free_tiles=2, + base_addr=sca, + ) + + # ========================================================================= + # RoPE SBUF buffers (only allocated when cos_cache/sin_cache 
provided) + # All ROPE_HALF=32 partition rows starting at 0 to satisfy NCC_IBIR297. + # ========================================================================= + fuse_rope = cos_cache is not None and sin_cache is not None + + if fuse_rope: + rope_half = rope_dim // 2 # 32 + + # Q RoPE buffers: all (ROPE_HALF=32, Q_GRP_SZ=128) bf16 + sca = _align32(sca) + cos_lo_q_sb = nl.ndarray( + (ROPE_HALF, Q_GRP_SZ), dtype=nl.bfloat16, buffer=nl.sbuf + ) + sca += Q_GRP_SZ * 2 + cos_hi_q_sb = nl.ndarray( + (ROPE_HALF, Q_GRP_SZ), dtype=nl.bfloat16, buffer=nl.sbuf + ) + sca += Q_GRP_SZ * 2 + sin_q_sb = nl.ndarray((ROPE_HALF, Q_GRP_SZ), dtype=nl.bfloat16, buffer=nl.sbuf) + sca += Q_GRP_SZ * 2 + rope_q_x1 = nl.ndarray((ROPE_HALF, Q_GRP_SZ), dtype=nl.bfloat16, buffer=nl.sbuf) + sca += Q_GRP_SZ * 2 + rope_q_x2 = nl.ndarray((ROPE_HALF, Q_GRP_SZ), dtype=nl.bfloat16, buffer=nl.sbuf) + sca += Q_GRP_SZ * 2 + rope_q_res1 = nl.ndarray( + (ROPE_HALF, Q_GRP_SZ), dtype=nl.bfloat16, buffer=nl.sbuf + ) + sca += Q_GRP_SZ * 2 + rope_q_res2 = nl.ndarray( + (ROPE_HALF, Q_GRP_SZ), dtype=nl.bfloat16, buffer=nl.sbuf + ) + sca += Q_GRP_SZ * 2 + + # K RoPE buffers: all (ROPE_HALF=32, K_TILE_SZ=512) bf16 + # Reused across K tiles (sequential processing) + sca = _align32(sca) + cos_lo_k_sb = nl.ndarray( + (ROPE_HALF, K_TILE_SZ), dtype=nl.bfloat16, buffer=nl.sbuf + ) + sca += K_TILE_SZ * 2 + cos_hi_k_sb = nl.ndarray( + (ROPE_HALF, K_TILE_SZ), dtype=nl.bfloat16, buffer=nl.sbuf + ) + sca += K_TILE_SZ * 2 + sin_k_sb = nl.ndarray((ROPE_HALF, K_TILE_SZ), dtype=nl.bfloat16, buffer=nl.sbuf) + sca += K_TILE_SZ * 2 + rope_k_x1 = nl.ndarray( + (ROPE_HALF, K_TILE_SZ), dtype=nl.bfloat16, buffer=nl.sbuf + ) + sca += K_TILE_SZ * 2 + rope_k_x2 = nl.ndarray( + (ROPE_HALF, K_TILE_SZ), dtype=nl.bfloat16, buffer=nl.sbuf + ) + sca += K_TILE_SZ * 2 + rope_k_res1 = nl.ndarray( + (ROPE_HALF, K_TILE_SZ), dtype=nl.bfloat16, buffer=nl.sbuf + ) + sca += K_TILE_SZ * 2 + rope_k_res2 = nl.ndarray( + (ROPE_HALF, K_TILE_SZ), 
dtype=nl.bfloat16, buffer=nl.sbuf + ) + sca += K_TILE_SZ * 2 + else: + cos_lo_q_sb = cos_hi_q_sb = sin_q_sb = None + rope_q_x1 = rope_q_x2 = rope_q_res1 = rope_q_res2 = None + cos_lo_k_sb = cos_hi_k_sb = sin_k_sb = None + rope_k_x1 = rope_k_x2 = rope_k_res1 = rope_k_res2 = None + + # ========================================================================= + # GQA outer loop: iterate over Q heads sharing this KV head + # ========================================================================= + for i_q_h in range(q_h_per_k_h): + q_head_idx = kv_head_id * q_h_per_k_h + i_q_h + + # PSUM allocations (compiler-managed placement) + # Allocated per-GQA-iteration to avoid NCC_ISCH715 accumulation conflicts + # MM1 PSUM: QK matmul results + mm1_psum = [] + for grp_idx in range(num_grps): + grp_row = [] + for lt_idx in range(num_large_tiles): + tile_row = [] + for kt_idx in range(4): + tile = nl.ndarray( + (Q_GRP_SZ, PSUM_FMAX), + dtype=nl.float32, + buffer=nl.psum, + ) + tile_row.append(tile) + grp_row.append(tile_row) + mm1_psum.append(grp_row) + + # MM2 PSUM lo: PV result for d[0:128] + mm2_psum_lo = [] + for grp_idx in range(num_grps): + grp_row = [] + for lt_idx in range(num_large_tiles): + tile = nl.ndarray( + (Q_GRP_SZ, D_TILE), + dtype=nl.float32, + buffer=nl.psum, + ) + grp_row.append(tile) + mm2_psum_lo.append(grp_row) + + # MM2 PSUM hi: PV result for d[128:256] + mm2_psum_hi = [] + for grp_idx in range(num_grps): + grp_row = [] + for lt_idx in range(num_large_tiles): + tile = nl.ndarray( + (Q_GRP_SZ, D_TILE), + dtype=nl.float32, + buffer=nl.psum, + ) + grp_row.append(tile) + mm2_psum_hi.append(grp_row) + + # Load K and V (shared across Q heads in same GQA group) + # K is BHSD: (1, 1, S, D=256). DMA transpose to (D=128, S=512) in SBUF. 
+ # Flat offset: k[batch_id, kv_head_id, k_start, 0] + for k_idx in nl.affine_range(num_k_tiles): + k_start = k_idx * K_TILE_SZ + num_k = min(seqlen_kv - k_start, K_TILE_SZ) + k_offset = ( + batch_id * n_kv_heads * seqlen_kv * d + + kv_head_id * seqlen_kv * d + + k_start * d + ) + # Lo half: D[0:128], transpose (S=num_k, D=128) -> (D=128, S=num_k) + # ap() dim0 stride must be d=256 (full head dim), not D_TILE=128 + nisa.dma_transpose( + dst=k_sb_lo[k_idx].ap( + [[K_TILE_SZ, D_TILE], [1, 1], [1, 1], [1, num_k]] + ), + src=k.ap( + [[d, num_k], [1, 1], [1, 1], [1, D_TILE]], + offset=k_offset, + ), + ) + # Hi half: D[128:256], transpose (S=num_k, D=128) -> (D=128, S=num_k) + nisa.dma_transpose( + dst=k_sb_hi[k_idx].ap( + [[K_TILE_SZ, D_TILE], [1, 1], [1, 1], [1, num_k]] + ), + src=k.ap( + [[d, num_k], [1, 1], [1, 1], [1, D_TILE]], + offset=k_offset + D_TILE, + ), + ) + + for v_idx in nl.affine_range(num_v_tiles): + v_start = v_idx * V_TILE_SZ + num_v = min(seqlen_kv - v_start, V_TILE_SZ) + nisa.dma_copy( + dst=v_sb[v_idx][:num_v, :D_HEAD], + src=v[batch_id, kv_head_id, v_start : v_start + num_v, 0:D_HEAD], + ) + + # Apply RoPE to K tiles (sequential — shared cos/sin/tmp buffers) + if fuse_rope: + for k_idx in nl.sequential_range(num_k_tiles): + _load_rope_cos_sin_k( + k_idx, + cos_lo_k_sb, + cos_hi_k_sb, + sin_k_sb, + cos_cache, + sin_cache, + seqlen_kv, + rope_dim, + ) + _apply_rope_k_sbuf( + k_idx, + k_sb_lo, + cos_lo_k_sb, + cos_hi_k_sb, + sin_k_sb, + rope_k_x1, + rope_k_x2, + rope_k_res1, + rope_k_res2, + seqlen_kv, + ) + + # Zero the bias buffer once + nisa.memset(wb_zero_bias, value=0.0) + + # ===================================================================== + # Software Pipeline + # ===================================================================== + if num_grps <= 1: + # Single group -- no pipelining + _pipe_load_q( + 0, + q_sb_lo, + q_sb_hi, + q, + D_TILE, + seqlen_q, + batch_id, + q_head_idx, + n_heads, + d, + fuse_rope=fuse_rope, + 
cos_lo_q_sb=cos_lo_q_sb if fuse_rope else None, + cos_hi_q_sb=cos_hi_q_sb if fuse_rope else None, + sin_q_sb=sin_q_sb if fuse_rope else None, + rope_q_x1=rope_q_x1 if fuse_rope else None, + rope_q_x2=rope_q_x2 if fuse_rope else None, + rope_q_res1=rope_q_res1 if fuse_rope else None, + rope_q_res2=rope_q_res2 if fuse_rope else None, + cos_cache_hbm=cos_cache if fuse_rope else None, + sin_cache_hbm=sin_cache if fuse_rope else None, + rope_dim=rope_dim, + ) + _pipe_qk_and_max( + 0, + q_sb_lo, + q_sb_hi, + k_sb_lo, + k_sb_hi, + mm1_masked, + mm1_partial_max, + mm1_psum, + mm1_copy_sb, + mm1_asel_sb, + seqlen_q, + seqlen_kv, + scale, + num_k_tiles, + num_large_tiles, + use_causal_mask, + ) + _pipe_update_max( + 0, + mm1_partial_max, + mm1_section_max, + mm1_running_max, + num_k_tiles, + seqlen_q, + ) + _pipe_exp( + 0, + mm1_masked, + mm1_running_max, + exp_sb, + exp_partial_sum, + exp_tp_sb, + seqlen_q, + seqlen_kv, + num_large_tiles, + num_k_tiles, + ) + _pipe_pv( + 0, + exp_tp_sb, + v_sb, + mm2_psum_lo, + mm2_psum_hi, + mm2_sb, + seqlen_q, + seqlen_kv, + num_large_tiles, + num_v_tiles, + ) + _pipe_write_back( + 0, + mm2_sb, + exp_partial_sum, + exp_sum_recip, + wb_exp_section_sum, + wb_zero_bias, + wb_o_bf16, + o, + seqlen_q, + num_exp_tiles, + batch_id, + q_head_idx, + ) + + elif num_grps == 2: + # Two groups -- partial pipelining + _pipe_load_q( + 0, + q_sb_lo, + q_sb_hi, + q, + D_TILE, + seqlen_q, + batch_id, + q_head_idx, + n_heads, + d, + fuse_rope=fuse_rope, + cos_lo_q_sb=cos_lo_q_sb if fuse_rope else None, + cos_hi_q_sb=cos_hi_q_sb if fuse_rope else None, + sin_q_sb=sin_q_sb if fuse_rope else None, + rope_q_x1=rope_q_x1 if fuse_rope else None, + rope_q_x2=rope_q_x2 if fuse_rope else None, + rope_q_res1=rope_q_res1 if fuse_rope else None, + rope_q_res2=rope_q_res2 if fuse_rope else None, + cos_cache_hbm=cos_cache if fuse_rope else None, + sin_cache_hbm=sin_cache if fuse_rope else None, + rope_dim=rope_dim, + ) + _pipe_qk_and_max( + 0, + q_sb_lo, + q_sb_hi, + 
k_sb_lo, + k_sb_hi, + mm1_masked, + mm1_partial_max, + mm1_psum, + mm1_copy_sb, + mm1_asel_sb, + seqlen_q, + seqlen_kv, + scale, + num_k_tiles, + num_large_tiles, + use_causal_mask, + ) + _pipe_update_max( + 0, + mm1_partial_max, + mm1_section_max, + mm1_running_max, + num_k_tiles, + seqlen_q, + ) + _pipe_exp( + 0, + mm1_masked, + mm1_running_max, + exp_sb, + exp_partial_sum, + exp_tp_sb, + seqlen_q, + seqlen_kv, + num_large_tiles, + num_k_tiles, + ) + + _pipe_load_q( + 1, + q_sb_lo, + q_sb_hi, + q, + D_TILE, + seqlen_q, + batch_id, + q_head_idx, + n_heads, + d, + fuse_rope=fuse_rope, + cos_lo_q_sb=cos_lo_q_sb if fuse_rope else None, + cos_hi_q_sb=cos_hi_q_sb if fuse_rope else None, + sin_q_sb=sin_q_sb if fuse_rope else None, + rope_q_x1=rope_q_x1 if fuse_rope else None, + rope_q_x2=rope_q_x2 if fuse_rope else None, + rope_q_res1=rope_q_res1 if fuse_rope else None, + rope_q_res2=rope_q_res2 if fuse_rope else None, + cos_cache_hbm=cos_cache if fuse_rope else None, + sin_cache_hbm=sin_cache if fuse_rope else None, + rope_dim=rope_dim, + ) + _pipe_qk_and_max( + 1, + q_sb_lo, + q_sb_hi, + k_sb_lo, + k_sb_hi, + mm1_masked, + mm1_partial_max, + mm1_psum, + mm1_copy_sb, + mm1_asel_sb, + seqlen_q, + seqlen_kv, + scale, + num_k_tiles, + num_large_tiles, + use_causal_mask, + ) + _pipe_update_max( + 1, + mm1_partial_max, + mm1_section_max, + mm1_running_max, + num_k_tiles, + seqlen_q, + ) + + _pipe_pv( + 0, + exp_tp_sb, + v_sb, + mm2_psum_lo, + mm2_psum_hi, + mm2_sb, + seqlen_q, + seqlen_kv, + num_large_tiles, + num_v_tiles, + ) + _pipe_write_back( + 0, + mm2_sb, + exp_partial_sum, + exp_sum_recip, + wb_exp_section_sum, + wb_zero_bias, + wb_o_bf16, + o, + seqlen_q, + num_exp_tiles, + batch_id, + q_head_idx, + ) + + _pipe_exp( + 1, + mm1_masked, + mm1_running_max, + exp_sb, + exp_partial_sum, + exp_tp_sb, + seqlen_q, + seqlen_kv, + num_large_tiles, + num_k_tiles, + ) + _pipe_pv( + 1, + exp_tp_sb, + v_sb, + mm2_psum_lo, + mm2_psum_hi, + mm2_sb, + seqlen_q, + seqlen_kv, + 
num_large_tiles, + num_v_tiles, + ) + _pipe_write_back( + 1, + mm2_sb, + exp_partial_sum, + exp_sum_recip, + wb_exp_section_sum, + wb_zero_bias, + wb_o_bf16, + o, + seqlen_q, + num_exp_tiles, + batch_id, + q_head_idx, + ) + + else: + # Full 3-stage pipelining (num_grps >= 3) + + # Prologue: prime groups 0 and 1 + _pipe_load_q( + 0, + q_sb_lo, + q_sb_hi, + q, + D_TILE, + seqlen_q, + batch_id, + q_head_idx, + n_heads, + d, + fuse_rope=fuse_rope, + cos_lo_q_sb=cos_lo_q_sb if fuse_rope else None, + cos_hi_q_sb=cos_hi_q_sb if fuse_rope else None, + sin_q_sb=sin_q_sb if fuse_rope else None, + rope_q_x1=rope_q_x1 if fuse_rope else None, + rope_q_x2=rope_q_x2 if fuse_rope else None, + rope_q_res1=rope_q_res1 if fuse_rope else None, + rope_q_res2=rope_q_res2 if fuse_rope else None, + cos_cache_hbm=cos_cache if fuse_rope else None, + sin_cache_hbm=sin_cache if fuse_rope else None, + rope_dim=rope_dim, + ) + _pipe_qk_and_max( + 0, + q_sb_lo, + q_sb_hi, + k_sb_lo, + k_sb_hi, + mm1_masked, + mm1_partial_max, + mm1_psum, + mm1_copy_sb, + mm1_asel_sb, + seqlen_q, + seqlen_kv, + scale, + num_k_tiles, + num_large_tiles, + use_causal_mask, + ) + _pipe_update_max( + 0, + mm1_partial_max, + mm1_section_max, + mm1_running_max, + num_k_tiles, + seqlen_q, + ) + _pipe_exp( + 0, + mm1_masked, + mm1_running_max, + exp_sb, + exp_partial_sum, + exp_tp_sb, + seqlen_q, + seqlen_kv, + num_large_tiles, + num_k_tiles, + ) + + _pipe_load_q( + 1, + q_sb_lo, + q_sb_hi, + q, + D_TILE, + seqlen_q, + batch_id, + q_head_idx, + n_heads, + d, + fuse_rope=fuse_rope, + cos_lo_q_sb=cos_lo_q_sb if fuse_rope else None, + cos_hi_q_sb=cos_hi_q_sb if fuse_rope else None, + sin_q_sb=sin_q_sb if fuse_rope else None, + rope_q_x1=rope_q_x1 if fuse_rope else None, + rope_q_x2=rope_q_x2 if fuse_rope else None, + rope_q_res1=rope_q_res1 if fuse_rope else None, + rope_q_res2=rope_q_res2 if fuse_rope else None, + cos_cache_hbm=cos_cache if fuse_rope else None, + sin_cache_hbm=sin_cache if fuse_rope else None, + 
rope_dim=rope_dim, + ) + _pipe_qk_and_max( + 1, + q_sb_lo, + q_sb_hi, + k_sb_lo, + k_sb_hi, + mm1_masked, + mm1_partial_max, + mm1_psum, + mm1_copy_sb, + mm1_asel_sb, + seqlen_q, + seqlen_kv, + scale, + num_k_tiles, + num_large_tiles, + use_causal_mask, + ) + _pipe_update_max( + 1, + mm1_partial_max, + mm1_section_max, + mm1_running_max, + num_k_tiles, + seqlen_q, + ) + + # Main pipelined loop + for grp_i in range(0, num_grps - 2): + _pipe_load_q( + grp_i + 2, + q_sb_lo, + q_sb_hi, + q, + D_TILE, + seqlen_q, + batch_id, + q_head_idx, + n_heads, + d, + fuse_rope=fuse_rope, + cos_lo_q_sb=cos_lo_q_sb if fuse_rope else None, + cos_hi_q_sb=cos_hi_q_sb if fuse_rope else None, + sin_q_sb=sin_q_sb if fuse_rope else None, + rope_q_x1=rope_q_x1 if fuse_rope else None, + rope_q_x2=rope_q_x2 if fuse_rope else None, + rope_q_res1=rope_q_res1 if fuse_rope else None, + rope_q_res2=rope_q_res2 if fuse_rope else None, + cos_cache_hbm=cos_cache if fuse_rope else None, + sin_cache_hbm=sin_cache if fuse_rope else None, + rope_dim=rope_dim, + ) + _pipe_exp( + grp_i + 1, + mm1_masked, + mm1_running_max, + exp_sb, + exp_partial_sum, + exp_tp_sb, + seqlen_q, + seqlen_kv, + num_large_tiles, + num_k_tiles, + ) + _pipe_fused_qkmax_and_pv( + grp_i, + q_sb_lo, + q_sb_hi, + k_sb_lo, + k_sb_hi, + mm1_masked, + mm1_partial_max, + mm1_psum, + mm1_copy_sb, + mm1_asel_sb, + exp_tp_sb, + v_sb, + mm2_psum_lo, + mm2_psum_hi, + mm2_sb, + seqlen_q, + seqlen_kv, + scale, + num_k_tiles, + num_large_tiles, + num_v_tiles, + use_causal_mask, + ) + _pipe_write_back( + grp_i, + mm2_sb, + exp_partial_sum, + exp_sum_recip, + wb_exp_section_sum, + wb_zero_bias, + wb_o_bf16, + o, + seqlen_q, + num_exp_tiles, + batch_id, + q_head_idx, + ) + _pipe_update_max( + grp_i + 2, + mm1_partial_max, + mm1_section_max, + mm1_running_max, + num_k_tiles, + seqlen_q, + ) + + # Epilogue: drain last 2 groups + _pipe_pv( + num_grps - 2, + exp_tp_sb, + v_sb, + mm2_psum_lo, + mm2_psum_hi, + mm2_sb, + seqlen_q, + seqlen_kv, + 
# ============================================================================
# Unit test
# ============================================================================
def reference_causal_attention(q, k, v):
    """Return causal softmax attention computed in float32 as a reference.

    All inputs are in BHSD layout.

    Args:
        q: Query tensor of shape (b, h, sq, d).
        k: Key tensor of shape (b, h, sk, d).
        v: Value tensor of shape (b, h, sk, d).

    Returns:
        float32 tensor of shape (b, h, sq, d). Query row ``i`` attends to key
        columns ``0..i`` (mask is ``triu(diagonal=1)`` of an ``sq x sk`` grid).
    """
    import torch

    q_f = q.float()
    k_f = k.float()
    v_f = v.float()
    scale = 1.0 / (q_f.shape[3] ** 0.5)
    scores = q_f @ k_f.transpose(-2, -1) * scale
    # Build the mask on the inputs' device so the helper is not tied to the
    # default device.
    future = torch.triu(
        torch.ones(
            q_f.shape[2], k_f.shape[2], dtype=torch.bool, device=q_f.device
        ),
        diagonal=1,
    )
    scores = scores.masked_fill(future, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v_f


def rotate_half(x):
    """Standard rotate_half for RoPE: map (x1, x2) halves to (-x2, x1)."""
    import torch

    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def apply_partial_rope(q, k, cos, sin, rope_dim=64):
    """Apply partial RoPE to Q and K (only the first ``rope_dim`` dimensions).

    Args:
        q: (B, H, S, D) query tensor, BHSD layout.
        k: (B, Hkv, S, D) key tensor, BHSD layout.
        cos: (S, rope_dim) cosine cache.
        sin: (S, rope_dim) sine cache.
        rope_dim: Number of leading head dimensions that are rotated; the
            remaining dimensions pass through unchanged.

    Returns:
        Tuple ``(q_out, k_out)`` with the same shapes as ``q`` and ``k``.
    """
    import torch

    # Expand cos/sin so they broadcast over batch and head: (1, 1, S, rope_dim)
    cos_b = cos.unsqueeze(0).unsqueeze(0)
    sin_b = sin.unsqueeze(0).unsqueeze(0)

    def _rotate(x):
        x_rope = x[..., :rope_dim]
        x_pass = x[..., rope_dim:]
        x_rope = x_rope * cos_b + rotate_half(x_rope) * sin_b
        return torch.cat([x_rope, x_pass], dim=-1)

    return _rotate(q), _rotate(k)


def _make_rope_caches(seq_len, rope_dim, theta=10_000_000):
    """Return bf16 ``(cos, sin)`` caches of shape (seq_len, rope_dim)."""
    import torch

    inv_freq = 1.0 / (theta ** (torch.arange(0, rope_dim, 2).float() / rope_dim))
    positions = torch.arange(seq_len).float()
    freqs = torch.outer(positions, inv_freq)  # (S, rope_dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (S, rope_dim)
    return emb.cos().to(torch.bfloat16), emb.sin().to(torch.bfloat16)


def _similarity_metrics(ref, out):
    """Return ``(cosine_similarity, max_abs_diff)`` between two tensors."""
    import torch.nn.functional as F

    cos_sim = F.cosine_similarity(
        ref.reshape(-1).unsqueeze(0), out.reshape(-1).unsqueeze(0)
    ).item()
    max_diff = (ref - out).abs().max().item()
    return cos_sim, max_diff


if __name__ == "__main__":
    import sys
    import time

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    d = 256
    rope_dim = 64

    def _run_case(case, fused_rope):
        """Run one kernel-vs-CPU comparison and return True when it passes.

        With ``fused_rope`` the kernel receives pre-RoPE Q/K plus cos/sin
        caches while the CPU reference receives post-RoPE Q/K; otherwise both
        sides receive the same tensors and no RoPE arguments are passed.
        """
        seq_len = case["seq"]
        bs = case["bs"]
        heads = case["heads"]
        kv_heads = case["kv_heads"]
        q_h_per_kv = heads // kv_heads

        print(f"\n=== Testing: {case['label']} ===")
        torch.manual_seed(42)
        q = torch.randn(bs, heads, seq_len, d, dtype=torch.bfloat16)
        k = torch.randn(bs, kv_heads, seq_len, d, dtype=torch.bfloat16)
        v = torch.randn(bs, kv_heads, seq_len, d, dtype=torch.bfloat16)

        q_ref, k_ref = q, k
        cos_cache = sin_cache = None
        if fused_rope:
            # Use realistic RoPE frequencies (theta=10M, rope_dim=64)
            cos_cache, sin_cache = _make_rope_caches(seq_len, rope_dim)
            q_ref, k_ref = apply_partial_rope(q, k, cos_cache, sin_cache, rope_dim)

        # CPU reference, one Q head at a time against its shared KV head
        ref = torch.cat(
            [
                reference_causal_attention(
                    q_ref[:, h : h + 1],
                    k_ref[:, h // q_h_per_kv : h // q_h_per_kv + 1],
                    v[:, h // q_h_per_kv : h // q_h_per_kv + 1],
                )
                for h in range(heads)
            ],
            dim=1,
        )

        q_dev = q.to(device)
        k_dev = k.to(device)
        v_dev = v.to(device)
        rope_kwargs = {}
        if fused_rope:
            rope_kwargs = {
                "cos_cache": cos_cache.to(device),
                "sin_cache": sin_cache.to(device),
                "rope_dim": rope_dim,
            }

        t0 = time.time()
        per_batch = []
        for b in range(bs):
            per_kv_head = []
            for kv_h in range(kv_heads):
                q_slice = q_dev[
                    b : b + 1, kv_h * q_h_per_kv : (kv_h + 1) * q_h_per_kv, :, :
                ]
                k_slice = k_dev[b : b + 1, kv_h : kv_h + 1, :, :]
                v_slice = v_dev[b : b + 1, kv_h : kv_h + 1, :, :]
                per_kv_head.append(
                    flash_attn_d256_pipe(
                        q_slice,
                        k_slice,
                        v_slice,
                        use_causal_mask=True,
                        q_h_per_k_h=q_h_per_kv,
                        n_kv_heads=1,
                        seqlen_q=seq_len,
                        seqlen_kv=seq_len,
                        **rope_kwargs,
                    )
                )
            # Heads are joined on dim=1 per batch element, batches on dim=0.
            # (A single flat cat on dim=1 would mix batches when bs > 1.)
            per_batch.append(torch.cat(per_kv_head, dim=1))
        out = torch.cat(per_batch, dim=0)
        xm.mark_step()
        out_cpu = out.cpu().float()
        t1 = time.time()

        cos_sim, maxd = _similarity_metrics(ref, out_cpu)
        ok = cos_sim > 0.999
        print(f" Time: {t1 - t0:.1f}s (includes compile)")
        print(f" Cosine sim: {cos_sim:.6f}")
        print(f" Max diff: {maxd:.6f}")
        print(f" {'PASS' if ok else 'FAIL'}")
        return ok

    tests_basic = [
        {"seq": 512, "bs": 1, "heads": 1, "kv_heads": 1, "label": "seq=512 1:1"},
        {"seq": 1024, "bs": 1, "heads": 1, "kv_heads": 1, "label": "seq=1024 1:1"},
        {"seq": 512, "bs": 1, "heads": 4, "kv_heads": 1, "label": "seq=512 GQA 4:1"},
        {"seq": 1024, "bs": 1, "heads": 4, "kv_heads": 1, "label": "seq=1024 GQA 4:1"},
    ]
    tests_rope = [{**t, "label": f"ROPE {t['label']}"} for t in tests_basic]

    failed = []

    # ====================================================================
    # Part 1: Original tests (post-RoPE inputs, no cos/sin — backward compat)
    # ====================================================================
    print("=" * 70)
    print("PART 1: Post-RoPE inputs (backward compatible, no fused RoPE)")
    print("=" * 70)
    for t in tests_basic:
        if not _run_case(t, fused_rope=False):
            failed.append(t["label"])

    # ====================================================================
    # Part 2: Fused RoPE tests (pre-RoPE inputs + cos/sin caches)
    # ====================================================================
    print("\n" + "=" * 70)
    print("PART 2: Fused RoPE (pre-RoPE inputs + cos/sin caches)")
    print("=" * 70)
    for t in tests_rope:
        if not _run_case(t, fused_rope=True):
            failed.append(t["label"])

    # Surface failures through the exit status so automation notices them.
    if failed:
        sys.exit(1)
+ First-token validation isolates CTE accuracy from TKG drift. + +Hardware Requirements: + - trn2.48xlarge (TP=8, LNC=2) + - Neuron SDK 2.30 + - ~149 GB disk for model weights + +Usage: + # Run with pytest + MODEL_PATH=/mnt/models/Qwen3-Coder-Next pytest test_model.py -v + + # Run standalone + MODEL_PATH=/mnt/models/Qwen3-Coder-Next python test_model.py +""" + +import json +import os +import sys +import time + +import pytest +import torch +import numpy as np +from pathlib import Path +from transformers import AutoTokenizer, AutoConfig + +from neuronx_distributed_inference.models.config import MoENeuronConfig + +# Import from src directory +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) +from modeling_qwen35_moe import ( + NeuronQwen35MoeForCausalLM, + Qwen35MoeInferenceConfig, +) + +# Configuration from environment +MODEL_PATH = os.environ.get("MODEL_PATH", "/mnt/models/Qwen3-Coder-Next") +COMPILED_MODEL_PATH = os.environ.get("COMPILED_PATH", "/mnt/compiled_qwen3_test/") + +# Pre-verified first-token predictions from CPU reference (transformers BF16). +# Format: (prompt, expected_top1_token_str, min_cosine_similarity) +REFERENCE_FIRST_TOKENS = [ + ("The capital of France is", "Paris", 0.99), + ("The sky is", "blue", 0.99), + ( + "Water boils at", + " ", + 0.99, + ), # space token (model predicts whitespace before number) + ("The capital of Germany is", "Berlin", 0.99), + ("Machine learning is a subset of", "artificial", 0.99), + ("def fibonacci(n):\n if n <=", " ", 0.99), + ("SELECT * FROM users WHERE", "email", 0.99), +] + +# Pre-verified multi-token greedy outputs from CPU reference. 
REFERENCE_GREEDY_OUTPUTS = {
    "The capital of France is": "Paris",
    "1 + 1 =": " ",  # model outputs space then number
}


def make_load_config(model_path):
    """Create a config loader that reads from the HF config at ``model_path``.

    Returns:
        A callable taking the config object; it copies every key of the
        HuggingFace config dict onto that object, skipping keys that start
        with ``_`` and the ``transformers_version`` key.
    """

    def _load_config(config_self):
        hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        for key, value in hf_config.to_dict().items():
            if not key.startswith("_") and key != "transformers_version":
                setattr(config_self, key, value)

    return _load_config


def create_config():
    """Create inference config for TP=8 on trn2.48xlarge."""
    neuron_config = MoENeuronConfig(
        tp_degree=8,
        max_batch_size=1,
        max_context_length=128,
        max_new_tokens=32,
        max_length=160,
        torch_dtype=torch.bfloat16,
        fused_qkv=True,
        moe_tp_degree=8,
        moe_ep_degree=1,
        enable_bucketing=True,
        context_encoding_buckets=[32],
        blockwise_matmul_config={
            "block_size": 128,
            "use_shard_on_block_dynamic_while": True,
            "block_sharding_strategy": "PING_PONG",
        },
    )

    inference_config = Qwen35MoeInferenceConfig(
        neuron_config=neuron_config,
        load_config=make_load_config(MODEL_PATH),
    )
    return inference_config


@pytest.fixture(scope="module")
def compiled_model():
    """Compile (if no cached ``model.pt`` exists) and load the model.

    Module-scoped so all tests in this file reuse one loaded model.
    """
    os.environ["NEURON_CC_FLAGS"] = "--auto-cast matmult --auto-cast-type bf16"

    config = create_config()
    model = NeuronQwen35MoeForCausalLM(model_path=MODEL_PATH, config=config)

    compiled_path = Path(COMPILED_MODEL_PATH)
    if not (compiled_path / "model.pt").exists():
        print(f"Compiling model to {COMPILED_MODEL_PATH}...")
        os.makedirs(COMPILED_MODEL_PATH, exist_ok=True)
        model.compile(COMPILED_MODEL_PATH)
        print("Compilation complete")

    model.load(COMPILED_MODEL_PATH)
    return model


@pytest.fixture(scope="module")
def tokenizer():
    """Load tokenizer, falling back to EOS as the pad token when none is set."""
    tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok


def _token_matches(expected, predicted):
    """Return True when ``predicted`` satisfies the reference token ``expected``.

    Non-blank expectations use a case-insensitive substring match on the
    stripped strings. A whitespace-only expectation requires a non-empty,
    whitespace-only prediction; stripping it would leave the empty string,
    which is a substring of every prediction and would make the check vacuous.
    """
    needle = expected.strip().lower()
    if needle:
        return needle in predicted.strip().lower()
    return predicted != "" and predicted.strip() == ""


def _forward(model, input_ids, position_ids, seq_ids):
    """Run one forward pass with an all-ones attention mask over ``input_ids``."""
    with torch.no_grad():
        return model.forward(
            input_ids=input_ids,
            attention_mask=torch.ones(1, input_ids.shape[1], dtype=torch.int32),
            position_ids=position_ids,
            seq_ids=seq_ids,
        )


def _last_token_logits(out):
    """Extract the logit vector of the final position from a forward output."""
    logits = out[0][0]
    if logits.dim() == 2:
        logits = logits[-1]
    return logits


def _prefill(model, tokenizer, prompt, seq_ids):
    """Reset the model, encode ``prompt`` and return ``(last_logits, n_tokens)``."""
    model.reset()
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    n = input_ids.shape[1]
    out = _forward(
        model,
        input_ids,
        torch.arange(n, dtype=torch.long).unsqueeze(0),
        seq_ids,
    )
    return _last_token_logits(out), n


def _greedy_decode(model, tokenizer, prompt, max_new_tokens):
    """Greedy-decode ``max_new_tokens`` token ids after ``prompt``.

    The prompt occupies positions ``0..n-1``, so the k-th generated token
    (k = 0, 1, ...) is fed back at position ``n + k``, the same convention
    ``TestPerformance`` uses (``5 + t`` after a 5-token prompt).
    """
    seq_ids = torch.zeros(1, dtype=torch.long)
    logits, n = _prefill(model, tokenizer, prompt, seq_ids)
    next_id = logits.argmax().item()
    generated = [next_id]

    for step in range(max_new_tokens - 1):
        out = _forward(
            model,
            torch.tensor([[next_id]], dtype=torch.int32),
            torch.tensor([[n + step]], dtype=torch.long),
            seq_ids,
        )
        next_id = _last_token_logits(out).argmax().item()
        generated.append(next_id)

    return generated


class TestModelLoading:
    """Smoke tests for model loading."""

    def test_model_loads(self, compiled_model):
        """Model loads successfully with correct config."""
        assert compiled_model is not None
        assert hasattr(compiled_model, "config")
        assert compiled_model.config.neuron_config.tp_degree == 8

    def test_model_has_correct_layers(self, compiled_model):
        """Model has expected number of layers."""
        assert compiled_model.config.num_hidden_layers == 48


class TestFirstTokenAccuracy:
    """First-token accuracy validation.

    After context encoding, the top-1 token is decoded and compared against
    the pre-verified CPU reference token for each prompt. Only the top-1
    token is checked; the cosine threshold stored in REFERENCE_FIRST_TOKENS
    is not evaluated because this file carries no reference logits.
    """

    def test_first_token_predictions(self, compiled_model, tokenizer):
        """Validate first-token predictions match CPU reference for all test prompts."""
        seq_ids = torch.zeros(1, dtype=torch.long)
        failures = []

        for prompt, expected_token, _min_cosine in REFERENCE_FIRST_TOKENS:
            logits, _ = _prefill(compiled_model, tokenizer, prompt, seq_ids)
            top_val, top_idx = logits.float().topk(1)
            predicted_token = tokenizer.decode(top_idx[0])

            if not _token_matches(expected_token, predicted_token):
                failures.append(
                    f"Prompt: '{prompt}'\n"
                    f"Expected token containing: '{expected_token}'\n"
                    f"Got: '{predicted_token}' (logit={top_val[0].item():.2f})"
                )

        # Report every failing prompt at once instead of stopping at the first.
        passed = len(REFERENCE_FIRST_TOKENS) - len(failures)
        assert not failures, (
            f"Only {passed}/{len(REFERENCE_FIRST_TOKENS)} prompts passed\n"
            + "\n".join(failures)
        )


class TestGreedyGeneration:
    """Multi-token greedy generation accuracy."""

    def test_greedy_matches_reference(self, compiled_model, tokenizer):
        """Greedy decoded tokens match pre-verified CPU reference outputs."""
        for prompt, expected_substring in REFERENCE_GREEDY_OUTPUTS.items():
            # Generate up to 10 tokens
            generated = _greedy_decode(compiled_model, tokenizer, prompt, 10)
            output_text = tokenizer.decode(generated, skip_special_tokens=True)

            # Search the generated text only: the prompt itself may already
            # contain the expected substring (e.g. a space).
            assert expected_substring in output_text, (
                f"Expected '{expected_substring}' in output for prompt '{prompt}',\n"
                f"got: '{prompt + output_text}'"
            )

    def test_output_not_repetitive(self, compiled_model, tokenizer):
        """Generated output is coherent, not degenerate repetition."""
        prompt = "def quicksort(arr):\n"

        generated = _greedy_decode(compiled_model, tokenizer, prompt, 30)
        output_text = tokenizer.decode(generated, skip_special_tokens=True)
        tokens = output_text.split()

        # Check no single token repeats 8+ times consecutively
        run_length = 1
        for previous, current in zip(tokens, tokens[1:]):
            run_length = run_length + 1 if current == previous else 1
            assert run_length < 8, (
                f"Degenerate repetition detected: '{current}' repeated 8+ times\n"
                f"Full output: {output_text}"
            )


class TestPerformance:
    """Performance sanity checks."""

    def test_throughput_above_minimum(self, compiled_model, tokenizer):
        """Token generation throughput exceeds minimum threshold."""
        seq_ids = torch.zeros(1, dtype=torch.long)
        context_len = 5

        def prefill_dummy_context():
            compiled_model.reset()
            _forward(
                compiled_model,
                torch.ones(1, context_len, dtype=torch.int32),
                torch.arange(context_len, dtype=torch.long).unsqueeze(0),
                seq_ids,
            )

        # Warmup
        for _ in range(3):
            prefill_dummy_context()

        # Measure TKG
        prefill_dummy_context()

        num_tokens = 20
        start = time.perf_counter()
        for t in range(num_tokens):
            _forward(
                compiled_model,
                torch.ones(1, 1, dtype=torch.int32),
                torch.tensor([[context_len + t]], dtype=torch.long),
                seq_ids,
            )
        elapsed = time.perf_counter() - start

        throughput = num_tokens / elapsed
        # Minimum threshold: 30 tok/s (well below measured 77 tok/s)
        assert throughput > 30, (
            f"Throughput {throughput:.1f} tok/s below 30 tok/s minimum threshold"
        )
position_ids=torch.arange(n, dtype=torch.long).unsqueeze(0), + seq_ids=seq_ids, + ) + + logits = out[0][0] + if logits.dim() == 2: + logits = logits[-1] + top_val, top_idx = logits.float().topk(1) + predicted = tok.decode(top_idx[0]) + + match = expected_token.strip().lower() in predicted.strip().lower() + status = "PASS" if match else "FAIL" + print( + f" {status}: '{prompt}' -> '{predicted.strip()}' (expected: '{expected_token}')" + ) + if match: + passed += 1 + + print(f"\n Result: {passed}/{len(REFERENCE_FIRST_TOKENS)} passed") + assert passed == len(REFERENCE_FIRST_TOKENS) + + print("\n" + "-" * 70) + print("Test 3: Greedy Generation") + print("-" * 70) + for prompt, expected in REFERENCE_GREEDY_OUTPUTS.items(): + model.reset() + input_ids = tok(prompt, return_tensors="pt").input_ids + n = input_ids.shape[1] + + with torch.no_grad(): + out = model.forward( + input_ids=input_ids, + attention_mask=torch.ones(1, n, dtype=torch.int32), + position_ids=torch.arange(n, dtype=torch.long).unsqueeze(0), + seq_ids=seq_ids, + ) + logits = out[0][0][-1] if out[0][0].dim() == 2 else out[0][0] + next_id = logits.argmax().item() + generated = [next_id] + + for t in range(9): + tkg_ids = torch.tensor([[next_id]], dtype=torch.int32) + tkg_pos = torch.tensor([[n + t + 1]], dtype=torch.long) + with torch.no_grad(): + out = model.forward( + input_ids=tkg_ids, + attention_mask=torch.ones(1, 1, dtype=torch.int32), + position_ids=tkg_pos, + seq_ids=seq_ids, + ) + logits = out[0][0][-1] if out[0][0].dim() == 2 else out[0][0] + next_id = logits.argmax().item() + generated.append(next_id) + + text = tok.decode(generated, skip_special_tokens=True) + full = prompt + text + match = expected in full + print(f" {'PASS' if match else 'FAIL'}: '{prompt}' -> '{text[:60]}'") + assert match + + print("\n" + "-" * 70) + print("Test 4: Performance") + print("-" * 70) + # Warmup + for _ in range(3): + model.reset() + ids = torch.ones(1, 5, dtype=torch.int32) + with torch.no_grad(): + model.forward( 
+ input_ids=ids, + attention_mask=torch.ones(1, 5, dtype=torch.int32), + position_ids=torch.arange(5, dtype=torch.long).unsqueeze(0), + seq_ids=seq_ids, + ) + + model.reset() + ids = torch.ones(1, 5, dtype=torch.int32) + with torch.no_grad(): + model.forward( + input_ids=ids, + attention_mask=torch.ones(1, 5, dtype=torch.int32), + position_ids=torch.arange(5, dtype=torch.long).unsqueeze(0), + seq_ids=seq_ids, + ) + + num_tokens = 20 + start = time.perf_counter() + for t in range(num_tokens): + tkg_ids = torch.ones(1, 1, dtype=torch.int32) + tkg_pos = torch.tensor([[5 + t]], dtype=torch.long) + with torch.no_grad(): + model.forward( + input_ids=tkg_ids, + attention_mask=torch.ones(1, 1, dtype=torch.int32), + position_ids=tkg_pos, + seq_ids=seq_ids, + ) + elapsed = time.perf_counter() - start + throughput = num_tokens / elapsed + print(f" Throughput: {throughput:.1f} tok/s (threshold: 30 tok/s)") + assert throughput > 30 + + print("\n" + "=" * 70) + print("ALL TESTS PASSED") + print("=" * 70) diff --git a/contrib/models/Qwen3-Coder-Next/test/unit/__init__.py b/contrib/models/Qwen3-Coder-Next/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Qwen3-Coder-Next/vllm/register_model.py b/contrib/models/Qwen3-Coder-Next/vllm/register_model.py new file mode 100644 index 00000000..56758aa6 --- /dev/null +++ b/contrib/models/Qwen3-Coder-Next/vllm/register_model.py @@ -0,0 +1,30 @@ +""" +Register Qwen3-Coder-Next (qwen3_next) with NxDI's MODEL_TYPES registry. + +This must be imported BEFORE vLLM loads the model so that +_get_neuron_model_cls("Qwen3NextForCausalLM") can find our class. 
+ +Usage: + import register_model # patches MODEL_TYPES + # then launch vllm normally +""" + +import sys +import os + +# Ensure our contrib src is importable +CONTRIB_SRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "src") +if CONTRIB_SRC not in sys.path: + sys.path.insert(0, os.path.abspath(CONTRIB_SRC)) + +from neuronx_distributed_inference.utils.constants import MODEL_TYPES +from modeling_qwen35_moe import NeuronQwen35MoeForCausalLM + +# Register under "qwen3next" (what vLLM derives from "Qwen3NextForCausalLM") +# The key format is: architecture.split("For")[0].lower() -> model name +# "Qwen3NextForCausalLM" -> model="qwen3next", task="causal-lm" +MODEL_TYPES["qwen3next"] = {"causal-lm": NeuronQwen35MoeForCausalLM} + +print( + f"[register_model] Registered 'qwen3next' -> {NeuronQwen35MoeForCausalLM.__name__}" +) diff --git a/contrib/models/Qwen3-Coder-Next/vllm/start_vllm_server.sh b/contrib/models/Qwen3-Coder-Next/vllm/start_vllm_server.sh new file mode 100755 index 00000000..8970bf86 --- /dev/null +++ b/contrib/models/Qwen3-Coder-Next/vllm/start_vllm_server.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# Launch vLLM server for Qwen3-Coder-Next on trn2.48xlarge (TP=8) +# +# Prerequisites: +# - Model downloaded to /mnt/Qwen3-Coder-Next +# - vLLM venv: /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate +# - Contrib code at /home/ubuntu/Qwen3-Coder-Next-contrib/src/ +# - vllm-neuron patched with qwen3next model type registration +# +# Compilation: +# First run will compile ~15 min. NEFFs cached in neuron compile cache. +# If NEURON_COMPILED_ARTIFACTS is set, will attempt to load from there first. 
+# +# Usage: +# ./start_vllm_server.sh [PORT] +# +# Known limitations: +# - max_context_length=128 (larger buckets fail to compile with on-device sampling) +# - max_batch_size=1 (single request serving) +# - Total max_model_len=256 (128 context + 128 generation) + +set -euo pipefail + +PORT="${1:-8000}" +MODEL_PATH="/mnt/Qwen3-Coder-Next" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CONTRIB_SRC="${SCRIPT_DIR}/../src" + +# Activate vLLM environment +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate + +# Set environment +export VLLM_NEURON_FRAMEWORK="neuronx-distributed-inference" +export PYTHONPATH="${CONTRIB_SRC}:${PYTHONPATH:-}" + +echo "================================================" +echo " Qwen3-Coder-Next vLLM Server" +echo " Model: ${MODEL_PATH}" +echo " Port: ${PORT}" +echo " TP: 8, Context: 128, Gen: 128" +echo "================================================" + +# Launch vLLM OpenAI-compatible server +python -m vllm.entrypoints.openai.api_server \ + --model "${MODEL_PATH}" \ + --port "${PORT}" \ + --tensor-parallel-size 8 \ + --max-model-len 256 \ + --max-num-seqs 1 \ + --block-size 128 \ + --no-enable-chunked-prefill \ + --no-enable-prefix-caching \ + --additional-config '{ + "override_neuron_config": { + "tp_degree": 8, + "max_batch_size": 1, + "max_context_length": 128, + "max_new_tokens": 128, + "max_length": 256, + "torch_dtype": "bfloat16", + "fused_qkv": true, + "on_device_sampling_config": {"dynamic": true, "deterministic": false}, + "moe_tp_degree": 8, + "moe_ep_degree": 1, + "blockwise_matmul_config": { + "block_size": 128, + "use_shard_on_block_dynamic_while": true, + "block_sharding_strategy": "PING_PONG" + } + } + }' diff --git a/contrib/models/Qwen3-Coder-Next/vllm/test_vllm_client.py b/contrib/models/Qwen3-Coder-Next/vllm/test_vllm_client.py new file mode 100644 index 00000000..bd4b9a77 --- /dev/null +++ b/contrib/models/Qwen3-Coder-Next/vllm/test_vllm_client.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 
"""Test client for Qwen3-Coder-Next vLLM server.

Sends a health check, a model listing, one text completion and two chat
completions to an OpenAI-compatible vLLM endpoint and prints the results.

Usage:
    python test_vllm_client.py [--host HOST] [--port PORT] [--model MODEL]
"""

import argparse
import json
import sys
import time

import requests

# Model id sent in every request; it must match the id the server reports
# under /v1/models.
DEFAULT_MODEL = "/mnt/Qwen3-Coder-Next"

# Generation requests can be slow, but must never block forever.
DEFAULT_TIMEOUT_S = 300.0

# Metadata endpoints (/health, /v1/models) should answer almost immediately.
PROBE_TIMEOUT_S = 10.0


def _post_json(url: str, payload: dict, timeout: float) -> tuple[dict, float]:
    """POST *payload* as JSON to *url* and time the round trip.

    Args:
        url: Full endpoint URL.
        payload: JSON-serialisable request body.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        A ``(decoded_json_body, elapsed_seconds)`` tuple. The elapsed time
        covers only the HTTP round trip, not JSON decoding.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-2xx response status.
    """
    # perf_counter is monotonic, so the measurement cannot be skewed (or go
    # negative) if the wall clock is adjusted while the request is in flight.
    t0 = time.perf_counter()
    resp = requests.post(url, json=payload, timeout=timeout)
    elapsed = time.perf_counter() - t0
    resp.raise_for_status()
    return resp.json(), elapsed


def chat_completion(
    base_url: str,
    messages: list[dict],
    max_tokens: int = 128,
    temperature: float = 0.0,
    *,
    model: str = DEFAULT_MODEL,
    timeout: float = DEFAULT_TIMEOUT_S,
):
    """Send a chat completion request.

    Args:
        base_url: Server root, e.g. ``http://localhost:8000``.
        messages: Chat messages, each a dict with ``role`` and ``content``.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        model: Model id to put in the request body.
        timeout: Seconds to wait for the response.

    Returns:
        A ``(response_json, elapsed_seconds)`` tuple.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-2xx response status.
    """
    payload = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return _post_json(f"{base_url}/v1/chat/completions", payload, timeout)


def completion(
    base_url: str,
    prompt: str,
    max_tokens: int = 128,
    temperature: float = 0.0,
    *,
    model: str = DEFAULT_MODEL,
    timeout: float = DEFAULT_TIMEOUT_S,
):
    """Send a text completion request.

    Args:
        base_url: Server root, e.g. ``http://localhost:8000``.
        prompt: Raw text prompt to complete.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        model: Model id to put in the request body.
        timeout: Seconds to wait for the response.

    Returns:
        A ``(response_json, elapsed_seconds)`` tuple.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-2xx response status.
    """
    payload = {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return _post_json(f"{base_url}/v1/completions", payload, timeout)


def main() -> int:
    """Run the smoke tests against the server.

    Returns:
        Process exit status: 0 when every request succeeded, 1 when the
        health check failed. Failures in later requests propagate as
        ``requests.RequestException``.
    """
    parser = argparse.ArgumentParser(description="Test Qwen3-Coder-Next vLLM server")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--model", type=str, default=DEFAULT_MODEL)
    args = parser.parse_args()

    base_url = f"http://{args.host}:{args.port}"

    print(f"Testing vLLM server at {base_url}")
    print("=" * 60)

    # Test 1: Health check
    print("\n--- Test 1: Health Check ---")
    try:
        resp = requests.get(f"{base_url}/health", timeout=PROBE_TIMEOUT_S)
        # An unhealthy server answers with a non-2xx status; treat that the
        # same as being unreachable instead of carrying on with the tests.
        resp.raise_for_status()
    except requests.RequestException as e:
        print(f"FAILED: {e}")
        # Non-zero exit so scripts and CI can detect the failure.
        return 1
    print(f"Status: {resp.status_code}")

    # Test 2: Model list
    print("\n--- Test 2: List Models ---")
    resp = requests.get(f"{base_url}/v1/models", timeout=PROBE_TIMEOUT_S)
    resp.raise_for_status()
    models = resp.json()
    first_model_id = models["data"][0]["id"] if models.get("data") else "none"
    print(f"Models: {json.dumps(first_model_id, indent=2)}")

    # Test 3: Simple completion
    print("\n--- Test 3: Completion (Fibonacci) ---")
    prompt = 'def fibonacci(n):\n """Return the nth Fibonacci number."""\n'
    result, elapsed = completion(base_url, prompt, max_tokens=64, model=args.model)
    text = result["choices"][0]["text"]
    tokens = result["usage"]["completion_tokens"]
    print(f"Prompt: {repr(prompt[:50])}")
    print(f"Output ({tokens} tokens, {elapsed:.2f}s, {tokens / elapsed:.1f} tok/s):")
    print(f" {text[:200]}")

    # Test 4: Chat completion
    print("\n--- Test 4: Chat Completion ---")
    messages = [
        {"role": "system", "content": "You are a helpful coding assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ]
    result, elapsed = chat_completion(
        base_url, messages, max_tokens=32, model=args.model
    )
    content = result["choices"][0]["message"]["content"]
    tokens = result["usage"]["completion_tokens"]
    print(f"Response ({tokens} tokens, {elapsed:.2f}s):")
    print(f" {content}")

    # Test 5: Code generation
    print("\n--- Test 5: Code Generation ---")
    messages = [
        {
            "role": "user",
            "content": "Write a Python function to check if a number is prime. Be concise.",
        },
    ]
    result, elapsed = chat_completion(
        base_url, messages, max_tokens=128, model=args.model
    )
    content = result["choices"][0]["message"]["content"]
    tokens = result["usage"]["completion_tokens"]
    print(f"Response ({tokens} tokens, {elapsed:.2f}s, {tokens / elapsed:.1f} tok/s):")
    print(f" {content[:300]}")

    print("\n" + "=" * 60)
    print("All tests complete!")
    return 0


if __name__ == "__main__":
    sys.exit(main())