Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions contrib/models/FlashVSR/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Contrib Model: FlashVSR

Video super-resolution (4x upscaling) on AWS Trainium using a streaming DiT architecture with NKI tiled flash attention.

## Model Information

- **HuggingFace ID:** `JunhaoZhuang/FlashVSR-v1.1`
- **Model Type:** Video super-resolution DiT (Denoising Diffusion Transformer)
- **Parameters:** ~1.3B (BF16) DiT + 288M LQ Projection + 45M TCDecoder
- **Architecture:** 30-layer DiT with factored 3D RoPE, LCSA self-attention, text cross-attention, AdaLN modulation, QK-norm with DistributedRMSNorm
- **Base Model:** Wan 2.1 1.3B (dim=1536, 12 heads, head_dim=128)
- **License:** Check HuggingFace model card

## Validation Results

**Validated:** 2026-05-26
**Instance:** trn2.3xlarge (LNC=2, 4 logical NeuronCores)
**SDK:** Neuron SDK 2.29.1, PyTorch 2.9, NKI 0.3.0

### Benchmark Results

| Metric | Value |
|--------|-------|
| End-to-end throughput | **10.3 FPS** (768x1280 output, 85 frames) |
| Total DiT time | 5.0s (1 first chunk + 8 stream chunks) |
| Total TCDecoder time (NxDI, co-resident) | 2.4s (22 calls × 89ms, HBM state persistence) |
| LQ Projection | 0.86s (single pass, all frames) |
| Model loading | DiT 40s + TCDecoder 1.8s (one-time startup) |

### Accuracy Validation

| Metric | Value |
|--------|-------|
| DiT neuron_allclose vs CPU (rtol=0.05, atol=0.1) | PASS |
| DiT max_rel_error | 0.025 |
| DiT cosine similarity | 0.9997 |
| DiT per-chunk latency (first chunk, f=6) | ~1720 ms |
| DiT per-chunk latency (stream, f=2) | ~410 ms |
| Full pipeline visual quality | Matches reference implementation (DMD single-step) |

## Usage

```python
from src.pipeline import compile_pipeline, load_pipeline, run_inference

# Step 1: Download weights (one-time)
# python -m src.download_weights --output-dir /path/to/FlashVSR-v1.1

# Step 2: Compile models (one-time per resolution)
compile_pipeline(
weights_dir="/path/to/FlashVSR-v1.1",
output_dir="/path/to/compiled",
height=768,
width=1280,
tp_degree=4,
)

# Step 3: Load compiled pipeline
pipeline = load_pipeline(
compiled_dir="/path/to/compiled",
weights_dir="/path/to/FlashVSR-v1.1",
prompt_path="/path/to/FlashVSR-v1.1/posi_prompt.pth",
tp_degree=4,
tcdecoder_path="/path/to/compiled/tcdecoder_seq.pt",
lq_proj_path="/path/to/compiled/lq_proj.pt",
)

# Step 4: Run inference
output_path = run_inference(
pipeline,
input_video="/path/to/input.mp4",
output_dir="/path/to/output",
scale=4,
)
```

## Pipeline Architecture

FlashVSR has three separately compiled Neuron components, all co-resident in HBM:

| Component | Compilation Method | TP Degree | Role |
|-----------|-------------------|-----------|------|
| DiT (first chunk) | NxDI ModelBuilder | TP=4 | Denoising, f=6 latent frames |
| DiT (stream chunk) | NxDI ModelBuilder | TP=4 | Denoising, f=2 latent frames |
| LQ Projection | torch_neuronx.trace | TP=1 | Generates conditioning tokens |
| TCDecoder | NxDI ModelBuilder | TP=4 | Latent-to-RGB (HBM state persistence) |

All models are loaded at startup and remain co-resident in HBM (total ~15 GB out of 96 GB available on trn2.3xlarge). This eliminates model transition overhead between pipeline stages.

The streaming architecture processes video in chunks: first chunk (6 latent frames = 24 output frames) followed by overlapping stream chunks (2 latent frames = 8 output frames each).

## Key Technical Details

- **NKI Flash Attention:** Uses `attention_cte` from nkilib -- tiles attention computation in SRAM, never materializes the full S*S attention matrix in HBM. Enables 23040-token sequences on trn2.3xlarge.
- **DistributedRMSNorm:** QK-norm with all-reduce across TP ranks for global variance computation. Essential for accuracy at TP>1.
- **Co-resident HBM models:** DiT (7.5 GB × 2) + TCDecoder (378 MB) all loaded simultaneously in 96 GB HBM. Eliminates model swap overhead between pipeline stages.
- **TCDecoder HBM State Persistence:** Uses `input_output_aliases` to keep 9 MemBlock states in device HBM between sequential calls. No PCIe state transfer per frame. Output reshaped from `(4, 3, H, W)` to `(1, 12, H, W)` inside the NEFF to prevent TP from sharding the temporal dimension across ranks.
- **Phase 2 LCSA (optional):** Block-sparse Locality-Constrained Sparse Attention behind `USE_BLOCK_SPARSE_LCSA` toggle. Generates per-layer sparse masks inside the traced graph via topk + index_select. Requires trn2.48xlarge with TP=16.
- **Single-step DMD:** FlashVSR-v1.1 uses Distribution Matching Distillation for single-step denoising (timestep=1000).

## Compatibility Matrix

| Instance/Config | SDK 2.29.1 | SDK 2.29 | SDK 2.28 |
|-----------------|------------|----------|----------|
| trn2.3xlarge, TP=4, LNC=2 | **VALIDATED (10.3 FPS)** | VALIDATED (8.3 FPS) | Not tested |

## Example Checkpoints

* [JunhaoZhuang/FlashVSR-v1.1](https://huggingface.co/JunhaoZhuang/FlashVSR-v1.1)

## Testing Instructions

```bash
# Run DiT accuracy test (neuron_allclose vs CPU reference)
pytest test/integration/test_dit_accuracy.py -v

# Run full pipeline E2E test (PSNR validation)
pytest test/integration/test_pipeline_e2e.py -v
```

## Known Issues

- **Resolution constraint:** Input video must produce output dimensions divisible by 128 (e.g., 768x1280). Other resolutions require recompilation.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you re-compile for all resolutions? Compiling for a resolution larger than one you've validated might result in a compilation error for having too many instructions.

- **Phase 2 LCSA:** Block-sparse attention requires trn2.48xlarge with TP=16 (not available on trn2.3xlarge). Production uses Phase 1 dense attention.
- **TCDecoder temporal recurrence:** Each frame must be processed serially due to MemBlock temporal dependencies. The NxDI HBM state persistence approach minimizes this cost (89ms/call vs 237ms with PCIe state transfer).
- **Text embedding:** Uses a pre-computed positive prompt embedding (`posi_prompt.pth`). Custom prompts require running the T5 text encoder separately.

## Maintainer

Jim Burtoft
28 changes: 28 additions & 0 deletions contrib/models/FlashVSR/src/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""FlashVSR: Video Super-Resolution on AWS Trainium using NxD Inference."""

from .modeling_flashvsr import NeuronFlashVSRDiT, FlashVSRDiTConfig
from .tcdecoder import (
NeuronTCDecoderSequential,
NeuronTCDecoderStateful,
TCDecoderApplication,
decode_video_nxdi,
)
from .lq_projection import NeuronLQProj
from .pipeline import FlashVSRPipeline, compile_pipeline, load_pipeline, run_inference

__all__ = [
"NeuronFlashVSRDiT",
"FlashVSRDiTConfig",
"NeuronTCDecoderSequential",
"NeuronTCDecoderStateful",
"TCDecoderApplication",
"decode_video_nxdi",
"NeuronLQProj",
"FlashVSRPipeline",
"compile_pipeline",
"load_pipeline",
"run_inference",
]
106 changes: 106 additions & 0 deletions contrib/models/FlashVSR/src/download_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#!/usr/bin/env python3
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Download and prepare FlashVSR-v1.1 weights for Neuron inference.

Downloads the model from HuggingFace and organizes weights into the expected
directory structure for the FlashVSR Neuron pipeline.

Usage:
python -m src.download_weights --output-dir /path/to/FlashVSR-v1.1

Required files from HuggingFace:
- JunhaoZhuang/FlashVSR-v1.1:
- diffusion_pytorch_model_streaming_dmd.safetensors (DiT weights)
- LQ_proj_in.ckpt (LQ projection weights)
- TCDecoder.ckpt (TCDecoder weights)
- posi_prompt.pth (pre-computed text embedding)
"""

import argparse
import os
import sys


def download_weights(output_dir: str, token: str = None):
"""Download FlashVSR-v1.1 weights from HuggingFace.

Args:
output_dir: Directory to save weights
token: HuggingFace access token (if model is gated)
"""
try:
from huggingface_hub import hf_hub_download, snapshot_download
except ImportError:
print("ERROR: huggingface_hub not installed. Run: pip install huggingface_hub")
sys.exit(1)

os.makedirs(output_dir, exist_ok=True)

repo_id = "JunhaoZhuang/FlashVSR-v1.1"
print(f"Downloading FlashVSR-v1.1 weights from {repo_id}...")

# Required files
required_files = [
"diffusion_pytorch_model_streaming_dmd.safetensors",
"LQ_proj_in.ckpt",
"TCDecoder.ckpt",
"posi_prompt.pth",
]

for filename in required_files:
target = os.path.join(output_dir, filename)
if os.path.exists(target):
print(f" [SKIP] {filename} already exists")
continue

print(f" Downloading {filename}...")
try:
hf_hub_download(
repo_id=repo_id,
filename=filename,
local_dir=output_dir,
token=token,
)
print(f" [OK] {filename}")
except Exception as e:
print(f" [WARN] Failed to download {filename}: {e}")

# Create symlink for NxDI checkpoint loader compatibility
actual = os.path.join(
output_dir, "diffusion_pytorch_model_streaming_dmd.safetensors"
)
symlink = os.path.join(output_dir, "diffusion_pytorch_model.safetensors")
if os.path.exists(actual) and not os.path.exists(symlink):
os.symlink(os.path.basename(actual), symlink)
print(f" Created symlink: diffusion_pytorch_model.safetensors")

print(f"\nWeights saved to: {output_dir}")
print(f"Contents:")
for f in sorted(os.listdir(output_dir)):
size_mb = os.path.getsize(os.path.join(output_dir, f)) / 1024 / 1024
print(f" {f} ({size_mb:.1f} MB)")


def main():
parser = argparse.ArgumentParser(description="Download FlashVSR-v1.1 weights")
parser.add_argument(
"--output-dir",
type=str,
required=True,
help="Directory to save weights",
)
parser.add_argument(
"--token",
type=str,
default=None,
help="HuggingFace access token (if model is gated)",
)
args = parser.parse_args()
download_weights(args.output_dir, args.token)


if __name__ == "__main__":
main()
Loading