-
Notifications
You must be signed in to change notification settings - Fork 37
Add FlashVSR contrib model with video super-resolution on Neuron #165
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jimburtoft
wants to merge
11
commits into
aws-neuron:main
Choose a base branch
from
jimburtoft:contrib/flashvsr
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 5 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
75981fd
Add FlashVSR contrib model with video super-resolution on Neuron
jimburtoft da89d16
Fix FlashVSR README: correct benchmark numbers, remove GPU reference
jimburtoft 26bd031
Update FlashVSR accuracy test tolerances based on hardware validation
jimburtoft 6db2715
Add NxDI TCDecoder with HBM state persistence (3.0x speedup)
jimburtoft 0434512
TCDecoder: co-resident TP=4 with output reshape fix (10.3 FPS)
jimburtoft ef3eabe
Add E2E FlashVSR notebook with Neuron AdaIN color correction (9.8 FPS)
jimburtoft 72d94ed
Address review feedback: narrow exception, guard private attr, add co…
jimburtoft e3eb76b
Add troubleshooting guide and improved repro instructions to README
jimburtoft 3464ba1
Add multi-bucket stream support for long-video optimization
jimburtoft 01dfc29
Add temporal_offset param to neuron_dit_forward and multi-bucket test
jimburtoft 1662f49
Document multi-bucket benchmark results: larger buckets are slower fo…
jimburtoft File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,130 @@ | ||
| # Contrib Model: FlashVSR | ||
|
|
||
| Video super-resolution (4x upscaling) on AWS Trainium using a streaming DiT architecture with NKI tiled flash attention. | ||
|
|
||
| ## Model Information | ||
|
|
||
| - **HuggingFace ID:** `JunhaoZhuang/FlashVSR-v1.1` | ||
| - **Model Type:** Video super-resolution DiT (Denoising Diffusion Transformer) | ||
| - **Parameters:** ~1.3B (BF16) DiT + 288M LQ Projection + 45M TCDecoder | ||
| - **Architecture:** 30-layer DiT with factored 3D RoPE, LCSA self-attention, text cross-attention, AdaLN modulation, QK-norm with DistributedRMSNorm | ||
| - **Base Model:** Wan 2.1 1.3B (dim=1536, 12 heads, head_dim=128) | ||
| - **License:** Check HuggingFace model card | ||
|
|
||
| ## Validation Results | ||
|
|
||
| **Validated:** 2026-05-26 | ||
| **Instance:** trn2.3xlarge (LNC=2, 4 logical NeuronCores) | ||
| **SDK:** Neuron SDK 2.29.1, PyTorch 2.9, NKI 0.3.0 | ||
|
|
||
| ### Benchmark Results | ||
|
|
||
| | Metric | Value | | ||
| |--------|-------| | ||
| | End-to-end throughput | **10.3 FPS** (768x1280 output, 85 frames) | | ||
| | Total DiT time | 5.0s (1 first chunk + 8 stream chunks) | | ||
| | Total TCDecoder time (NxDI, co-resident) | 2.4s (22 calls × 89ms, HBM state persistence) | | ||
| | LQ Projection | 0.86s (single pass, all frames) | | ||
| | Model loading | DiT 40s + TCDecoder 1.8s (one-time startup) | | ||
|
|
||
| ### Accuracy Validation | ||
|
|
||
| | Metric | Value | | ||
| |--------|-------| | ||
| | DiT neuron_allclose vs CPU (rtol=0.05, atol=0.1) | PASS | | ||
| | DiT max_rel_error | 0.025 | | ||
| | DiT cosine similarity | 0.9997 | | ||
| | DiT per-chunk latency (first chunk, f=6) | ~1720 ms | | ||
| | DiT per-chunk latency (stream, f=2) | ~410 ms | | ||
| | Full pipeline visual quality | Matches reference implementation (DMD single-step) | | ||
|
|
||
| ## Usage | ||
|
|
||
| ```python | ||
| from src.pipeline import compile_pipeline, load_pipeline, run_inference | ||
|
|
||
| # Step 1: Download weights (one-time) | ||
| # python -m src.download_weights --output-dir /path/to/FlashVSR-v1.1 | ||
|
|
||
| # Step 2: Compile models (one-time per resolution) | ||
| compile_pipeline( | ||
| weights_dir="/path/to/FlashVSR-v1.1", | ||
| output_dir="/path/to/compiled", | ||
| height=768, | ||
| width=1280, | ||
| tp_degree=4, | ||
| ) | ||
|
|
||
| # Step 3: Load compiled pipeline | ||
| pipeline = load_pipeline( | ||
| compiled_dir="/path/to/compiled", | ||
| weights_dir="/path/to/FlashVSR-v1.1", | ||
| prompt_path="/path/to/FlashVSR-v1.1/posi_prompt.pth", | ||
| tp_degree=4, | ||
| tcdecoder_path="/path/to/compiled/tcdecoder_seq.pt", | ||
| lq_proj_path="/path/to/compiled/lq_proj.pt", | ||
| ) | ||
|
|
||
| # Step 4: Run inference | ||
| output_path = run_inference( | ||
| pipeline, | ||
| input_video="/path/to/input.mp4", | ||
| output_dir="/path/to/output", | ||
| scale=4, | ||
| ) | ||
| ``` | ||
|
|
||
| ## Pipeline Architecture | ||
|
|
||
| FlashVSR has three separately compiled Neuron components, all co-resident in HBM: | ||
|
|
||
| | Component | Compilation Method | TP Degree | Role | | ||
| |-----------|-------------------|-----------|------| | ||
| | DiT (first chunk) | NxDI ModelBuilder | TP=4 | Denoising, f=6 latent frames | | ||
| | DiT (stream chunk) | NxDI ModelBuilder | TP=4 | Denoising, f=2 latent frames | | ||
| | LQ Projection | torch_neuronx.trace | TP=1 | Generates conditioning tokens | | ||
| | TCDecoder | NxDI ModelBuilder | TP=4 | Latent-to-RGB (HBM state persistence) | | ||
|
|
||
| All models are loaded at startup and remain co-resident in HBM (total ~15 GB out of 96 GB available on trn2.3xlarge). This eliminates model transition overhead between pipeline stages. | ||
|
|
||
| The streaming architecture processes video in chunks: first chunk (6 latent frames = 24 output frames) followed by overlapping stream chunks (2 latent frames = 8 output frames each). | ||
|
|
||
| ## Key Technical Details | ||
|
|
||
| - **NKI Flash Attention:** Uses `attention_cte` from nkilib -- tiles attention computation in SRAM, never materializes the full S*S attention matrix in HBM. Enables 23040-token sequences on trn2.3xlarge. | ||
| - **DistributedRMSNorm:** QK-norm with all-reduce across TP ranks for global variance computation. Essential for accuracy at TP>1. | ||
| - **Co-resident HBM models:** DiT (7.5 GB × 2) + TCDecoder (378 MB) all loaded simultaneously in 96 GB HBM. Eliminates model swap overhead between pipeline stages. | ||
| - **TCDecoder HBM State Persistence:** Uses `input_output_aliases` to keep 9 MemBlock states in device HBM between sequential calls. No PCIe state transfer per frame. Output reshaped from `(4, 3, H, W)` to `(1, 12, H, W)` inside the NEFF to prevent TP from sharding the temporal dimension across ranks. | ||
| - **Phase 2 LCSA (optional):** Block-sparse Locality-Constrained Sparse Attention behind `USE_BLOCK_SPARSE_LCSA` toggle. Generates per-layer sparse masks inside the traced graph via topk + index_select. Requires trn2.48xlarge with TP=16. | ||
| - **Single-step DMD:** FlashVSR-v1.1 uses Distribution Matching Distillation for single-step denoising (timestep=1000). | ||
|
|
||
| ## Compatibility Matrix | ||
|
|
||
| | Instance/Config | SDK 2.29.1 | SDK 2.29 | SDK 2.28 | | ||
| |-----------------|------------|----------|----------| | ||
| | trn2.3xlarge, TP=4, LNC=2 | **VALIDATED (10.3 FPS)** | VALIDATED (8.3 FPS) | Not tested | | ||
|
|
||
| ## Example Checkpoints | ||
|
|
||
| * [JunhaoZhuang/FlashVSR-v1.1](https://huggingface.co/JunhaoZhuang/FlashVSR-v1.1) | ||
|
|
||
| ## Testing Instructions | ||
|
|
||
| ```bash | ||
| # Run DiT accuracy test (neuron_allclose vs CPU reference) | ||
| pytest test/integration/test_dit_accuracy.py -v | ||
|
|
||
| # Run full pipeline E2E test (PSNR validation) | ||
| pytest test/integration/test_pipeline_e2e.py -v | ||
| ``` | ||
|
|
||
| ## Known Issues | ||
|
|
||
| - **Resolution constraint:** Input video must produce output dimensions divisible by 128 (e.g., 768x1280). Other resolutions require recompilation. | ||
| - **Phase 2 LCSA:** Block-sparse attention requires trn2.48xlarge with TP=16 (not available on trn2.3xlarge). Production uses Phase 1 dense attention. | ||
| - **TCDecoder temporal recurrence:** Each frame must be processed serially due to MemBlock temporal dependencies. The NxDI HBM state persistence approach minimizes this cost (89ms/call vs 237ms with PCIe state transfer). | ||
| - **Text embedding:** Uses a pre-computed positive prompt embedding (`posi_prompt.pth`). Custom prompts require running the T5 text encoder separately. | ||
|
|
||
| ## Maintainer | ||
|
|
||
| Jim Burtoft | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """FlashVSR: Video Super-Resolution on AWS Trainium using NxD Inference.""" | ||
|
|
||
| from .modeling_flashvsr import NeuronFlashVSRDiT, FlashVSRDiTConfig | ||
| from .tcdecoder import ( | ||
| NeuronTCDecoderSequential, | ||
| NeuronTCDecoderStateful, | ||
| TCDecoderApplication, | ||
| decode_video_nxdi, | ||
| ) | ||
| from .lq_projection import NeuronLQProj | ||
| from .pipeline import FlashVSRPipeline, compile_pipeline, load_pipeline, run_inference | ||
|
|
||
| __all__ = [ | ||
| "NeuronFlashVSRDiT", | ||
| "FlashVSRDiTConfig", | ||
| "NeuronTCDecoderSequential", | ||
| "NeuronTCDecoderStateful", | ||
| "TCDecoderApplication", | ||
| "decode_video_nxdi", | ||
| "NeuronLQProj", | ||
| "FlashVSRPipeline", | ||
| "compile_pipeline", | ||
| "load_pipeline", | ||
| "run_inference", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| #!/usr/bin/env python3 | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """ | ||
| Download and prepare FlashVSR-v1.1 weights for Neuron inference. | ||
|
|
||
| Downloads the model from HuggingFace and organizes weights into the expected | ||
| directory structure for the FlashVSR Neuron pipeline. | ||
|
|
||
| Usage: | ||
| python -m src.download_weights --output-dir /path/to/FlashVSR-v1.1 | ||
|
|
||
| Required files from HuggingFace: | ||
| - JunhaoZhuang/FlashVSR-v1.1: | ||
| - diffusion_pytorch_model_streaming_dmd.safetensors (DiT weights) | ||
| - LQ_proj_in.ckpt (LQ projection weights) | ||
| - TCDecoder.ckpt (TCDecoder weights) | ||
| - posi_prompt.pth (pre-computed text embedding) | ||
| """ | ||
|
|
||
| import argparse | ||
| import os | ||
| import sys | ||
|
|
||
|
|
||
| def download_weights(output_dir: str, token: str = None): | ||
| """Download FlashVSR-v1.1 weights from HuggingFace. | ||
|
|
||
| Args: | ||
| output_dir: Directory to save weights | ||
| token: HuggingFace access token (if model is gated) | ||
| """ | ||
| try: | ||
| from huggingface_hub import hf_hub_download, snapshot_download | ||
| except ImportError: | ||
| print("ERROR: huggingface_hub not installed. Run: pip install huggingface_hub") | ||
| sys.exit(1) | ||
|
|
||
| os.makedirs(output_dir, exist_ok=True) | ||
|
|
||
| repo_id = "JunhaoZhuang/FlashVSR-v1.1" | ||
| print(f"Downloading FlashVSR-v1.1 weights from {repo_id}...") | ||
|
|
||
| # Required files | ||
| required_files = [ | ||
| "diffusion_pytorch_model_streaming_dmd.safetensors", | ||
| "LQ_proj_in.ckpt", | ||
| "TCDecoder.ckpt", | ||
| "posi_prompt.pth", | ||
| ] | ||
|
|
||
| for filename in required_files: | ||
| target = os.path.join(output_dir, filename) | ||
| if os.path.exists(target): | ||
| print(f" [SKIP] {filename} already exists") | ||
| continue | ||
|
|
||
| print(f" Downloading {filename}...") | ||
| try: | ||
| hf_hub_download( | ||
| repo_id=repo_id, | ||
| filename=filename, | ||
| local_dir=output_dir, | ||
| token=token, | ||
| ) | ||
| print(f" [OK] {filename}") | ||
| except Exception as e: | ||
| print(f" [WARN] Failed to download {filename}: {e}") | ||
|
|
||
| # Create symlink for NxDI checkpoint loader compatibility | ||
| actual = os.path.join( | ||
| output_dir, "diffusion_pytorch_model_streaming_dmd.safetensors" | ||
| ) | ||
| symlink = os.path.join(output_dir, "diffusion_pytorch_model.safetensors") | ||
| if os.path.exists(actual) and not os.path.exists(symlink): | ||
| os.symlink(os.path.basename(actual), symlink) | ||
| print(f" Created symlink: diffusion_pytorch_model.safetensors") | ||
|
|
||
| print(f"\nWeights saved to: {output_dir}") | ||
| print(f"Contents:") | ||
| for f in sorted(os.listdir(output_dir)): | ||
| size_mb = os.path.getsize(os.path.join(output_dir, f)) / 1024 / 1024 | ||
| print(f" {f} ({size_mb:.1f} MB)") | ||
|
|
||
|
|
||
| def main(): | ||
| parser = argparse.ArgumentParser(description="Download FlashVSR-v1.1 weights") | ||
| parser.add_argument( | ||
| "--output-dir", | ||
| type=str, | ||
| required=True, | ||
| help="Directory to save weights", | ||
| ) | ||
| parser.add_argument( | ||
| "--token", | ||
| type=str, | ||
| default=None, | ||
| help="HuggingFace access token (if model is gated)", | ||
| ) | ||
| args = parser.parse_args() | ||
| download_weights(args.output_dir, args.token) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you re-compile for all resolutions? Compiling for a resolution larger than one you've validated might result in a compilation error for having too many instructions.