diff --git a/contrib/models/flux1-lite-8b/README.md b/contrib/models/flux1-lite-8b/README.md new file mode 100644 index 00000000..5143da9c --- /dev/null +++ b/contrib/models/flux1-lite-8b/README.md @@ -0,0 +1,222 @@ +# Contrib Model: FLUX.1-lite-8B-alpha + +FLUX.1-lite-8B-alpha image generation model running on AWS Neuron using NxDI's first-party FLUX.1 implementation with zero code modifications. + +## Model Information + +- **HuggingFace ID:** `Freepik/flux.1-lite-8B-alpha` +- **Model Type:** Diffusion transformer (DiT) for text-to-image generation +- **Parameters:** ~8B (BF16) +- **Architecture:** 8 double-stream MMDiT blocks + 38 single-stream DiT blocks, CLIP + T5-XXL text encoders, 16-channel VAE, FlowMatchEulerDiscrete scheduler +- **License:** Check HuggingFace model card (gated model, requires access approval) + +## Key Finding: Native NxDI FLUX.1 Compatibility + +**FLUX.1-lite-8B-alpha is architecturally identical to FLUX.1-dev** with only the number of double-stream blocks reduced (8 vs 19). All other components are the same: + +| Component | FLUX.1-dev | FLUX.1-lite-8B | Same? | +|-----------|-----------|----------------|-------| +| Double-stream (MMDiT) blocks | 19 | 8 | Different | +| Single-stream (DiT) blocks | 38 | 38 | Same | +| Attention heads | 24 | 24 | Same | +| Attention head dim | 128 | 128 | Same | +| Joint attention dim | 4096 | 4096 | Same | +| Text encoders | CLIP + T5-XXL | CLIP + T5-XXL | Same | +| VAE latent channels | 16 | 16 | Same | +| RoPE axes_dim | (16, 56, 56) | (16, 56, 56) | Same | +| Pipeline class | FluxPipeline | FluxPipeline | Same | +| Scheduler | FlowMatchEulerDiscrete | FlowMatchEulerDiscrete | Same | +| guidance_embeds | True | True | Same | + +Because NxDI's FLUX.1 implementation reads `num_layers` and `num_single_layers` from the model's `config.json` at runtime (via `load_diffusers_config()`), it automatically adapts to FLUX.1-lite's configuration. 
**No custom modeling code is needed.** + +This contrib provides: +- A standalone generation script (`src/generate_flux_lite.py`) for 1024x1024 +- A high-resolution script (`src/generate_flux_lite_highres.py`) for 2048x2048 and 4096x4096 +- Integration tests validating correct operation on Neuron +- Benchmark results demonstrating the performance benefit of the lighter architecture + +## Validation Results + +**Validated:** 2026-04-28 +**Instance:** trn2.3xlarge (LNC=2, 4 logical cores) +**SDK:** Neuron SDK 2.29 (DLAMI 20260410), PyTorch 2.9, NxD Inference 0.9 + +### Benchmark Results (1024x1024, 25 steps, guidance_scale=3.5) + +| Metric | Value | +|--------|-------| +| Resolution | 1024x1024 | +| Inference steps | 25 | +| TP Degree | 4 | +| CFG | Guidance distillation (single forward pass/step) | +| E2E generation time | 5.91s avg | +| Pipeline steps/sec | 4.23 | +| Backbone forward/sec | 4.49 | +| Compilation time | ~128s (CLIP 69s + T5 5s + backbone 53s + VAE ~2s) | + +## High-Resolution Generation (2K, 4K) + +> **Note:** The original FLUX.1-lite-8B model was trained and validated at 1024x1024 only. +> The original FLUX.1-dev/schnell models do not natively support 2K or 4K resolution either. +> High-resolution generation is an extrapolation beyond the training distribution — image +> quality may differ from native resolution. This capability is primarily useful for +> customers who have fine-tuned their own models at higher resolutions or want to evaluate +> the architecture's scaling behavior. 
+
+### Results Summary
+
+| Resolution | Tokens | Latency | Instance | Strategy |
+|-----------|--------|---------|----------|----------|
+| 1024x1024 | 4,096 | **5.91s** | trn2.3xlarge | TP=4 |
+| 2048x2048 | 16,384 | **31.53s** | trn2.3xlarge | TP=4 + tiled VAE (9 tiles) |
+| 4096x4096 | 65,536 | **107.25s** | trn2.48xlarge | TP=4, CP=4 + tiled VAE (25 tiles) |
+
+### How It Works
+
+**2048x2048 (16,384 tokens):**
+- The backbone (transformer) is compiled directly at 2K resolution. Self-attention operates over 16,384 tokens with TP=4, which fits in HBM on trn2.3xlarge (24 GB/core with LNC=2).
+- The VAE decoder exceeds the 5M instruction limit at 2K, so it is compiled at 1024x1024 and the 256x256 latent is decoded as a 3x3 grid of 9 overlapping tiles (128x128 latent pixels each, with 16 latent pixels of overlap, following the tiling logic in `src/generate_flux_lite_highres.py`).
+- No context parallelism is needed, and the hardware is the same as for 1K.
+
+**4096x4096 (65,536 tokens):**
+- The 65,536-token self-attention exceeds per-core HBM capacity on trn2.3xlarge even with TP=4.
+- Solution: **Context Parallelism (CP=4)** splits the sequence across 4 groups, giving each shard 16,384 tokens (identical to the working 2K case).
+- Configuration: `TP=4, CP=4, world_size=16` on trn2.48xlarge using 16 of 64 logical cores.
+- The VAE decoder is compiled at 1024x1024 and decodes the 512x512 latent as a 5x5 grid of 25 overlapping tiles.
+- **CRITICAL**: You must set `NEURON_RT_VISIBLE_CORES=0-15`. Otherwise the runtime detects all 64 cores and creates a 64-rank collective communicator, which deadlocks.
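The token counts, CP degree, and tile counts above follow from simple arithmetic. The sketch below reproduces that arithmetic using the constants and tiling rule from `src/generate_flux_lite_highres.py`; the helper names (`token_count`, `cp_degree`, `tile_starts`, `plan`) are illustrative and not part of NxDI, and it assumes the world size equals the TP degree when context parallelism is off.

```python
# Constants mirror src/generate_flux_lite_highres.py in this PR.
LATENT_TILE_SIZE = 1024 // 8  # VAE decoder is compiled for a 128x128 latent
LATENT_OVERLAP = 16           # overlap between neighbouring tiles, in latent pixels
LATENT_STRIDE = LATENT_TILE_SIZE - LATENT_OVERLAP  # 112
MAX_TOKENS_PER_SHARD = 16384  # sequence length reported to fit at TP=4


def token_count(height, width):
    """Backbone sequence length: 8x VAE downsampling, then 2x2 patches."""
    return (height * width) // 256


def cp_degree(height, width):
    """Context-parallel degree so each shard holds 16,384 tokens (1 if it already fits)."""
    tokens = token_count(height, width)
    return 1 if tokens <= MAX_TOKENS_PER_SHARD else tokens // MAX_TOKENS_PER_SHARD


def tile_starts(latent_size):
    """Start offsets of overlapping tiles along one latent axis."""
    if latent_size <= LATENT_TILE_SIZE:
        return [0]
    starts = list(range(0, latent_size - LATENT_TILE_SIZE + 1, LATENT_STRIDE))
    if starts[-1] + LATENT_TILE_SIZE < latent_size:
        # Add a final tile flush with the edge so the whole latent is covered.
        starts.append(latent_size - LATENT_TILE_SIZE)
    return starts


def plan(height, width, tp_degree=4):
    """Summarise parallelism and VAE tiling for a target resolution."""
    cp = cp_degree(height, width)
    rows = tile_starts(height // 8)
    cols = tile_starts(width // 8)
    return {
        "tokens": token_count(height, width),
        "cp_degree": cp,
        # Assumption: world size is TP * CP, and equals TP when CP is off.
        "world_size": tp_degree * cp,
        "vae_tiles": len(rows) * len(cols),
    }


if __name__ == "__main__":
    for size in (1024, 2048, 4096):
        print(size, plan(size, size))
```

With these constants, a 256x256 latent yields tile offsets `[0, 112, 128]` per axis (3x3 tiles) and a 512x512 latent yields `[0, 112, 224, 336, 384]` per axis (5x5 = 25 tiles). The last offset is clamped to the edge, so the final tile overlaps its neighbour by more than 16 latent pixels.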
+ +### High-Resolution Usage + +```bash +# 2K on trn2.3xlarge (same instance as 1K): +python src/generate_flux_lite_highres.py \ + --checkpoint_dir /shared/flux1-lite-8b \ + --height 2048 --width 2048 \ + --save_image --save_results + +# 4K on trn2.48xlarge (requires NEURON_RT_VISIBLE_CORES): +NEURON_RT_VISIBLE_CORES=0-15 python src/generate_flux_lite_highres.py \ + --checkpoint_dir /shared/flux1-lite-8b \ + --height 4096 --width 4096 \ + --save_image --save_results +``` + +### 4K Timing Breakdown + +| Phase | Time | Notes | +|-------|------|-------| +| Compilation (from scratch) | ~39 min | Backbone 25.8 min + VAE 7.3 min + encoders <1 min | +| Compilation (from cache) | 0s | NEFFs cached in compile_workdir | +| Model loading (cold) | ~40 min | 8B params across 16 cores from disk | +| Model loading (warm) | ~11s | NEFFs already in device memory | +| Backbone (25 steps) | 99.5s | 3.98s/step | +| Tiled VAE decode (25 tiles) | 7.3s | 0.29s/tile | +| **Total generation** | **107.25s** | Steady state (after warmup) | + +## Usage + +```python +import torch +from neuronx_distributed_inference.models.diffusers.flux.application import ( + NeuronFluxApplication, + create_flux_config, + get_flux_parallelism_config, +) + +MODEL_PATH = "/shared/flux1-lite-8b/" +COMPILE_DIR = "/tmp/flux-lite/compiled/" + +# Configure (reads num_layers=8 from model's config.json automatically) +world_size = get_flux_parallelism_config(backbone_tp_degree=4) +clip_cfg, t5_cfg, backbone_cfg, decoder_cfg = create_flux_config( + MODEL_PATH, world_size, backbone_tp_degree=4, + dtype=torch.bfloat16, height=1024, width=1024, +) + +# Create application +app = NeuronFluxApplication( + model_path=MODEL_PATH, + text_encoder_config=clip_cfg, + text_encoder2_config=t5_cfg, + backbone_config=backbone_cfg, + decoder_config=decoder_cfg, + height=1024, width=1024, +) + +# Compile + load +app.compile(COMPILE_DIR) +app.load(COMPILE_DIR) + +# Generate +image = app( + "A cat holding a sign that says hello world", + 
height=1024, width=1024, + guidance_scale=3.5, + num_inference_steps=25, +).images[0] +image.save("output.png") +``` + +Or use the provided script: + +```bash +python src/generate_flux_lite.py \ + --checkpoint_dir /shared/flux1-lite-8b \ + --compile_workdir /tmp/flux-lite/compiled/ \ + --prompt "A cat holding a sign that says hello world" \ + --height 1024 --width 1024 \ + --num_inference_steps 25 \ + --save_image +``` + +## Setup + +```bash +# Activate NxDI environment +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +# Install diffusers (not pre-installed in NxDI venv) +pip install diffusers transformers accelerate sentencepiece protobuf + +# Download model (requires HuggingFace token with access) +huggingface-cli download Freepik/flux.1-lite-8B-alpha \ + --local-dir /shared/flux1-lite-8b +``` + +## Compatibility Matrix + +| Instance | Resolution | SDK 2.29 | SDK 2.30 | +|----------|-----------|----------|----------| +| trn2.3xlarge (LNC=2, TP=4) | 1024x1024 | VALIDATED | VALIDATED | +| trn2.3xlarge (LNC=2, TP=4) | 2048x2048 | VALIDATED | VALIDATED | +| trn2.48xlarge (LNC=2, TP=4, CP=4) | 4096x4096 | Not tested | VALIDATED | + +## Example Checkpoints + +* [Freepik/flux.1-lite-8B-alpha](https://huggingface.co/Freepik/flux.1-lite-8B-alpha) + +## Testing Instructions + +```bash +# Set model path +export FLUX_LITE_MODEL_PATH=/shared/flux1-lite-8b/ + +# Run with pytest +cd contrib/models/flux1-lite-8b/ +pytest test/integration/test_model.py -v + +# Or standalone +python test/integration/test_model.py +``` + +## Known Issues + +- The NxDI venv (`/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/`) does not include `diffusers` by default. Install it with pip before running. +- `attention_cte` kernel warnings about batch size x seqlen_q x seqlen_k appear during inference. These are informational and do not affect output quality. 
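Because the first known issue only surfaces as an import error at run time, a small preflight check can catch it earlier. The sketch below uses only the standard library; `missing_modules` and `REQUIRED` are hypothetical names, and the list assumes the import names of the packages installed under Setup (`protobuf` is imported as `google.protobuf`).

```python
import importlib.util


def missing_modules(module_names):
    """Return the names from module_names that cannot be imported in this environment."""
    missing = []
    for name in module_names:
        try:
            found = importlib.util.find_spec(name) is not None
        except (ImportError, ValueError):
            # Raised when a parent package (e.g. `google` for `google.protobuf`) is absent
            # or a module has no usable spec.
            found = False
        if not found:
            missing.append(name)
    return missing


# Assumed import names for the packages listed under Setup.
REQUIRED = ["diffusers", "transformers", "accelerate", "sentencepiece", "google.protobuf"]

if __name__ == "__main__":
    absent = missing_modules(REQUIRED)
    if absent:
        print("Missing modules:", ", ".join(absent))
    else:
        print("All required modules are importable.")
```

Running this inside the activated NxDI venv before compiling avoids discovering a missing dependency after a long compile has already started.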
+ +## Sample Output + +![FLUX.1-lite output](samples/flux_lite_cat_hello_world.png) + +*"A cat holding a sign that says hello world" -- 1024x1024, 25 steps, guidance_scale=3.5* diff --git a/contrib/models/flux1-lite-8b/samples/flux_lite_cat_hello_world.png b/contrib/models/flux1-lite-8b/samples/flux_lite_cat_hello_world.png new file mode 100644 index 00000000..c7bca682 Binary files /dev/null and b/contrib/models/flux1-lite-8b/samples/flux_lite_cat_hello_world.png differ diff --git a/contrib/models/flux1-lite-8b/src/__init__.py b/contrib/models/flux1-lite-8b/src/__init__.py new file mode 100644 index 00000000..37470bb6 --- /dev/null +++ b/contrib/models/flux1-lite-8b/src/__init__.py @@ -0,0 +1,3 @@ +# NxDI FLUX.1-lite-8B-alpha Diffusion Model +# Demonstrates that FLUX.1-lite runs natively on NxDI's first-party FLUX.1 implementation +# with no code modifications -- only different model weights. diff --git a/contrib/models/flux1-lite-8b/src/generate_flux_lite.py b/contrib/models/flux1-lite-8b/src/generate_flux_lite.py new file mode 100644 index 00000000..349ea243 --- /dev/null +++ b/contrib/models/flux1-lite-8b/src/generate_flux_lite.py @@ -0,0 +1,147 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +FLUX.1-lite-8B-alpha generation script for AWS Neuron. + +FLUX.1-lite-8B-alpha (Freepik) is architecturally identical to FLUX.1-dev with a +reduced backbone: 8 double-stream MMDiT blocks instead of 19. It uses the same +CLIP + T5-XXL text encoders, FluxPipeline, VAE, scheduler, and RoPE configuration. + +Because of this architectural compatibility, FLUX.1-lite runs natively on NxDI's +first-party FLUX.1 implementation with no code modifications. 
The NxDI FLUX.1 +application reads `num_layers` and `num_single_layers` from the model's config.json +at runtime, so it automatically adapts to FLUX.1-lite's configuration: + - num_layers: 8 (vs 19 in FLUX.1-dev) + - num_single_layers: 38 (same as FLUX.1-dev) + +Usage: + # Download the model (requires HuggingFace access): + huggingface-cli download Freepik/flux.1-lite-8B-alpha --local-dir /shared/flux1-lite-8b + + # Generate an image: + python generate_flux_lite.py \\ + --checkpoint_dir /shared/flux1-lite-8b \\ + --compile_workdir /tmp/flux-lite/compiled/ \\ + --prompt "A cat holding a sign that says hello world" \\ + --height 1024 --width 1024 \\ + --num_inference_steps 25 \\ + --save_image + +Requirements: + pip install diffusers transformers accelerate sentencepiece protobuf +""" + +import argparse +import time + +import torch +from neuronx_distributed_inference.models.diffusers.flux.application import ( + NeuronFluxApplication, + create_flux_config, + get_flux_parallelism_config, +) +from neuronx_distributed_inference.utils.random import set_random_seed + +set_random_seed(0) + +DEFAULT_CKPT_DIR = "/shared/flux1-lite-8b/" +DEFAULT_COMPILE_DIR = "/tmp/flux-lite/compiled/" + + +def run_generate(args): + print(f"FLUX.1-lite-8B generation with args: {args}") + + backbone_tp_degree = args.backbone_tp_degree if args.backbone_tp_degree else 4 + world_size = get_flux_parallelism_config(backbone_tp_degree) + dtype = torch.bfloat16 + + clip_config, t5_config, backbone_config, decoder_config = create_flux_config( + args.checkpoint_dir, + world_size, + backbone_tp_degree, + dtype, + args.height, + args.width, + ) + + flux_app = NeuronFluxApplication( + model_path=args.checkpoint_dir, + text_encoder_config=clip_config, + text_encoder2_config=t5_config, + backbone_config=backbone_config, + decoder_config=decoder_config, + height=args.height, + width=args.width, + ) + + print("Compiling model...") + compile_start = time.time() + flux_app.compile(args.compile_workdir) + 
compile_time = time.time() - compile_start + print(f"Compilation completed in {compile_time:.1f}s") + + flux_app.load(args.compile_workdir) + + # Warmup + print("Warming up...") + for _ in range(args.warmup_rounds): + flux_app( + args.prompt, + height=args.height, + width=args.width, + guidance_scale=args.guidance_scale, + num_inference_steps=args.num_inference_steps, + ).images[0] + + # Generate + total_time = 0 + for i in range(args.num_images): + start = time.time() + image = flux_app( + args.prompt, + height=args.height, + width=args.width, + guidance_scale=args.guidance_scale, + num_inference_steps=args.num_inference_steps, + ).images[0] + gen_time = time.time() - start + total_time += gen_time + + if args.save_image: + filename = f"flux_lite_output_{i + 1}.png" + image.save(filename) + print(f"Image {i + 1} saved to {filename} in {gen_time:.2f}s") + else: + print(f"Image {i + 1} generated in {gen_time:.2f}s") + + avg_time = total_time / args.num_images + steps_per_sec = args.num_inference_steps / avg_time + print(f"\nResults:") + print(f" Average generation time: {avg_time:.2f}s") + print(f" Pipeline steps/sec: {steps_per_sec:.2f}") + print(f" Compilation time: {compile_time:.1f}s") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FLUX.1-lite-8B on AWS Neuron (NxDI)") + parser.add_argument( + "-p", "--prompt", type=str, default="A cat holding a sign that says hello world" + ) + parser.add_argument("-hh", "--height", type=int, default=1024) + parser.add_argument("-w", "--width", type=int, default=1024) + parser.add_argument("-n", "--num_inference_steps", type=int, default=25) + parser.add_argument("-g", "--guidance_scale", type=float, default=3.5) + parser.add_argument("-c", "--checkpoint_dir", type=str, default=DEFAULT_CKPT_DIR) + parser.add_argument("--compile_workdir", type=str, default=DEFAULT_COMPILE_DIR) + parser.add_argument("--num_images", type=int, default=3) + parser.add_argument("--warmup_rounds", type=int, default=5) 
+ parser.add_argument("--save_image", action="store_true") + parser.add_argument( + "--backbone_tp_degree", + type=int, + default=None, + help="Tensor parallelism degree (default: 4 for trn2.3xlarge)", + ) + args = parser.parse_args() + run_generate(args) diff --git a/contrib/models/flux1-lite-8b/src/generate_flux_lite_highres.py b/contrib/models/flux1-lite-8b/src/generate_flux_lite_highres.py new file mode 100644 index 00000000..9be237d7 --- /dev/null +++ b/contrib/models/flux1-lite-8b/src/generate_flux_lite_highres.py @@ -0,0 +1,377 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +FLUX.1-lite-8B high-resolution generation (2048x2048, 4096x4096) on AWS Neuron. + +This script extends the base FLUX.1-lite generation to high resolutions that are +NOT supported by the original model weights or standard inference pipelines: + + - 2048x2048 (16,384 tokens): Backbone compiles natively at TP=4 on trn2.3xlarge. + VAE decoder exceeds instruction limit at 2K, so tiled VAE decode is used. + + - 4096x4096 (65,536 tokens): Requires context parallelism (TP=4, CP=4) on + trn2.48xlarge. The backbone processes all 65,536 tokens natively with the + sequence split across 4 CP shards (16,384 tokens each). VAE decoder uses + tiled decode with 25 overlapping tiles. + +IMPORTANT: The original FLUX.1-lite-8B model was trained at 1024x1024. Generating +at 2K/4K is an extrapolation — image quality may vary compared to the native +resolution. This capability is useful for customers who have fine-tuned their own +models at higher resolutions or want to evaluate the architecture's scaling behavior. 
+ +Instance requirements: + - 2048x2048: trn2.3xlarge (LNC=2, TP=4) — same as 1K + - 4096x4096: trn2.48xlarge (LNC=2, 16 of 64 cores) + MUST set: NEURON_RT_VISIBLE_CORES=0-15 + +Usage: + # 2K on trn2.3xlarge: + python generate_flux_lite_highres.py \\ + --checkpoint_dir /shared/flux1-lite-8b \\ + --height 2048 --width 2048 + + # 4K on trn2.48xlarge (requires NEURON_RT_VISIBLE_CORES=0-15): + NEURON_RT_VISIBLE_CORES=0-15 python generate_flux_lite_highres.py \\ + --checkpoint_dir /shared/flux1-lite-8b \\ + --height 4096 --width 4096 +""" + +import argparse +import json +import os +import time + +import torch +from neuronx_distributed_inference.models.diffusers.flux.application import ( + NeuronFluxApplication, + create_flux_config, + get_flux_parallelism_config, +) +from neuronx_distributed_inference.utils.random import set_random_seed + +set_random_seed(0) + +# VAE decoder is always compiled at 1024x1024 output (128x128 latent). +# Higher resolutions use tiled decode over the compiled 1K decoder. +VAE_COMPILE_SIZE = 1024 +LATENT_TILE_SIZE = VAE_COMPILE_SIZE // 8 # 128 latent pixels +LATENT_OVERLAP = 16 # overlap in latent space (128 pixels in image space) +LATENT_STRIDE = LATENT_TILE_SIZE - LATENT_OVERLAP # 112 + + +def get_parallelism_config(height, width, backbone_tp_degree): + """Determine world_size and whether context parallelism is needed. + + Returns: + (world_size, context_parallel_enabled) + """ + tokens = (height * width) // 256 # patch_size=2 in 8x downsampled latent + + if tokens <= 16384: + # 1K (4096 tokens) or 2K (16384 tokens): standard TP, no CP needed + world_size = get_flux_parallelism_config(backbone_tp_degree) + return world_size, False + else: + # 4K (65536 tokens): need context parallelism to split the sequence + # CP = world_size / tp_degree. For 4K we want each shard to have + # 16384 tokens (same as working 2K), so CP = 65536 / 16384 = 4. 
+ cp_degree = tokens // 16384 + world_size = backbone_tp_degree * cp_degree + return world_size, True + + +def setup_tiled_vae_decode(flux_app): + """Monkey-patch the VAE decoder with a tiled decode implementation. + + The FLUX VAE decoder is compiled at 1024x1024 output (128x128 latent). + For higher resolutions, the full latent is decoded in overlapping tiles + with Gaussian-weighted blending for seamless stitching. + """ + original_vae_decode = flux_app.pipe.vae.decode + + def tiled_vae_decode(latent, **kwargs): + B, C, H, W = latent.shape + + # Extract return_dict from kwargs (pipeline passes return_dict=False) + return_dict = kwargs.pop("return_dict", True) + + if H <= LATENT_TILE_SIZE and W <= LATENT_TILE_SIZE: + return original_vae_decode(latent, return_dict=return_dict, **kwargs) + + # Compute tile grid positions + row_starts = list(range(0, H - LATENT_TILE_SIZE + 1, LATENT_STRIDE)) + if row_starts[-1] + LATENT_TILE_SIZE < H: + row_starts.append(H - LATENT_TILE_SIZE) + col_starts = list(range(0, W - LATENT_TILE_SIZE + 1, LATENT_STRIDE)) + if col_starts[-1] + LATENT_TILE_SIZE < W: + col_starts.append(W - LATENT_TILE_SIZE) + + n_tiles = len(row_starts) * len(col_starts) + print( + f" Tiled VAE decode: {H}x{W} latent -> " + f"{len(row_starts)}x{len(col_starts)} = {n_tiles} tiles" + ) + + # Output image dimensions + out_h = H * 8 + out_w = W * 8 + output = torch.zeros(B, 3, out_h, out_w) + weight = torch.zeros(1, 1, out_h, out_w) + + # Gaussian blend weight for seamless tile stitching + tile_out_size = LATENT_TILE_SIZE * 8 + y = torch.linspace(-1, 1, tile_out_size) + x = torch.linspace(-1, 1, tile_out_size) + yy, xx = torch.meshgrid(y, x, indexing="ij") + gauss = torch.exp(-(xx**2 + yy**2) * 3.0).unsqueeze(0).unsqueeze(0) + + tile_idx = 0 + vae_start = time.time() + for r in row_starts: + for c in col_starts: + tile_idx += 1 + tile_latent = latent[ + :, :, r : r + LATENT_TILE_SIZE, c : c + LATENT_TILE_SIZE + ].contiguous() + + # Decode tile + tile_result = 
original_vae_decode( + tile_latent, return_dict=False, **kwargs + ) + if isinstance(tile_result, tuple): + tile_pixels = tile_result[0] + elif hasattr(tile_result, "sample"): + tile_pixels = tile_result.sample + else: + tile_pixels = tile_result + + tile_pixels = tile_pixels.detach().cpu().float() + + # Accumulate with Gaussian weighting + out_r = r * 8 + out_c = c * 8 + output[ + :, :, out_r : out_r + tile_out_size, out_c : out_c + tile_out_size + ] += tile_pixels * gauss + weight[ + :, :, out_r : out_r + tile_out_size, out_c : out_c + tile_out_size + ] += gauss + + if tile_idx % 5 == 0 or tile_idx == n_tiles: + elapsed = time.time() - vae_start + print( + f" Tile {tile_idx}/{n_tiles} done " + f"({elapsed:.1f}s, {elapsed / tile_idx:.2f}s/tile)" + ) + + output = output / weight.clamp(min=1e-8) + + vae_total = time.time() - vae_start + print( + f" Tiled VAE decode complete: {vae_total:.1f}s " + f"({vae_total / n_tiles:.2f}s/tile)" + ) + + if return_dict: + from diffusers.models.autoencoders.vae import DecoderOutput + + return DecoderOutput(sample=output) + else: + return (output,) + + flux_app.pipe.vae.decode = tiled_vae_decode + return n_tiles_needed(flux_app.height, flux_app.width) + + +def n_tiles_needed(height, width): + """Compute the number of VAE decode tiles needed for a given resolution.""" + latent_h = height // 8 + latent_w = width // 8 + if latent_h <= LATENT_TILE_SIZE and latent_w <= LATENT_TILE_SIZE: + return 1 + row_starts = list(range(0, latent_h - LATENT_TILE_SIZE + 1, LATENT_STRIDE)) + if row_starts[-1] + LATENT_TILE_SIZE < latent_h: + row_starts.append(latent_h - LATENT_TILE_SIZE) + col_starts = list(range(0, latent_w - LATENT_TILE_SIZE + 1, LATENT_STRIDE)) + if col_starts[-1] + LATENT_TILE_SIZE < latent_w: + col_starts.append(latent_w - LATENT_TILE_SIZE) + return len(row_starts) * len(col_starts) + + +def run_generate(args): + height = args.height + width = args.width + backbone_tp_degree = args.backbone_tp_degree or 4 + tokens = (height * width) 
// 256 + + # Determine parallelism + world_size, context_parallel = get_parallelism_config( + height, width, backbone_tp_degree + ) + cp_degree = world_size // backbone_tp_degree + + print(f"FLUX.1-lite-8B High-Resolution Generation") + print(f" Resolution: {height}x{width} ({tokens} tokens)") + print(f" TP={backbone_tp_degree}, CP={cp_degree}, world_size={world_size}") + if context_parallel: + print(f" Tokens per CP shard: {tokens // cp_degree}") + print(f" VAE decoder compiled at {VAE_COMPILE_SIZE}x{VAE_COMPILE_SIZE}") + + # Check NEURON_RT_VISIBLE_CORES for 4K + if world_size > 8 and "NEURON_RT_VISIBLE_CORES" not in os.environ: + print( + "\nWARNING: NEURON_RT_VISIBLE_CORES is not set. For world_size > 8 on " + "trn2.48xlarge, you MUST set NEURON_RT_VISIBLE_CORES=0-15 to prevent " + "the runtime from creating 64-rank collectives that cause deadlock." + ) + print("Set: export NEURON_RT_VISIBLE_CORES=0-15") + return + + dtype = torch.bfloat16 + + # Create configs + clip_config, t5_config, backbone_config, decoder_config = create_flux_config( + args.checkpoint_dir, + world_size, + backbone_tp_degree, + dtype, + height, + width, + context_parallel_enabled=context_parallel, + ) + + # Override decoder to compile at 1K (will tile for higher resolutions) + needs_tiled_vae = (height > VAE_COMPILE_SIZE) or (width > VAE_COMPILE_SIZE) + if needs_tiled_vae: + decoder_config.height = VAE_COMPILE_SIZE + decoder_config.width = VAE_COMPILE_SIZE + vae_tiles = n_tiles_needed(height, width) + print(f" VAE tiling: {vae_tiles} tiles with {LATENT_OVERLAP}px overlap") + + # Create application + flux_app = NeuronFluxApplication( + model_path=args.checkpoint_dir, + text_encoder_config=clip_config, + text_encoder2_config=t5_config, + backbone_config=backbone_config, + decoder_config=decoder_config, + height=height, + width=width, + ) + + # Compile + print("\nCompiling model...") + compile_start = time.time() + flux_app.compile(args.compile_workdir) + compile_time = time.time() - 
compile_start + print(f"Compilation: {compile_time:.1f}s ({compile_time / 60:.1f} min)") + + # Load + print("Loading model...") + load_start = time.time() + flux_app.load(args.compile_workdir) + load_time = time.time() - load_start + print(f"Model loaded: {load_time:.1f}s ({load_time / 60:.1f} min)") + + # Setup tiled VAE decode if needed + if needs_tiled_vae: + setup_tiled_vae_decode(flux_app) + print("Tiled VAE decode enabled") + + # Warmup + print(f"\nGenerating warmup image...") + t0 = time.time() + image = flux_app( + args.prompt, + height=height, + width=width, + guidance_scale=args.guidance_scale, + num_inference_steps=args.num_inference_steps, + ).images[0] + warmup_time = time.time() - t0 + print(f"Warmup: {warmup_time:.2f}s") + + # Benchmark + print(f"\nBenchmarking ({args.num_images} rounds)...") + times = [] + for i in range(args.num_images): + t0 = time.time() + image = flux_app( + args.prompt, + height=height, + width=width, + guidance_scale=args.guidance_scale, + num_inference_steps=args.num_inference_steps, + ).images[0] + elapsed = time.time() - t0 + times.append(elapsed) + print(f" Image {i + 1}: {elapsed:.2f}s") + + avg_time = sum(times) / len(times) + print(f"\nResults:") + print(f" Resolution: {height}x{width}") + print(f" Steps: {args.num_inference_steps}") + print(f" Average: {avg_time:.2f}s/image") + print(f" Compilation: {compile_time:.1f}s") + print(f" Model load: {load_time:.1f}s") + + if args.save_image: + filename = f"flux_lite_{height}x{width}_output.png" + image.save(filename) + print(f" Saved: {filename} ({image.size})") + + if args.save_results: + results = { + "resolution": f"{height}x{width}", + "tokens": tokens, + "backbone_tp": backbone_tp_degree, + "cp_degree": cp_degree, + "world_size": world_size, + "context_parallel": context_parallel, + "num_steps": args.num_inference_steps, + "guidance_scale": args.guidance_scale, + "warmup_time_s": round(warmup_time, 2), + "benchmark_times_s": [round(t, 2) for t in times], + 
"average_time_s": round(avg_time, 2), + "compile_time_s": round(compile_time, 1), + "model_load_time_s": round(load_time, 1), + "vae_tiles": vae_tiles if needs_tiled_vae else 1, + } + results_file = f"flux_lite_{height}x{width}_results.json" + with open(results_file, "w") as f: + json.dump(results, f, indent=2) + print(f" Results: {results_file}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="FLUX.1-lite-8B high-resolution generation on AWS Neuron" + ) + parser.add_argument( + "-p", + "--prompt", + type=str, + default="A cat holding a sign that says hello world", + ) + parser.add_argument("-hh", "--height", type=int, default=2048) + parser.add_argument("-w", "--width", type=int, default=2048) + parser.add_argument("-n", "--num_inference_steps", type=int, default=25) + parser.add_argument("-g", "--guidance_scale", type=float, default=3.5) + parser.add_argument( + "-c", "--checkpoint_dir", type=str, default="/shared/flux1-lite-8b/" + ) + parser.add_argument( + "--compile_workdir", type=str, default="/tmp/flux-lite-highres/compiled/" + ) + parser.add_argument("--num_images", type=int, default=3) + parser.add_argument("--save_image", action="store_true") + parser.add_argument("--save_results", action="store_true") + parser.add_argument( + "--backbone_tp_degree", + type=int, + default=None, + help="Tensor parallelism degree (default: 4)", + ) + args = parser.parse_args() + run_generate(args) diff --git a/contrib/models/flux1-lite-8b/test/__init__.py b/contrib/models/flux1-lite-8b/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/flux1-lite-8b/test/integration/__init__.py b/contrib/models/flux1-lite-8b/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/flux1-lite-8b/test/integration/test_model.py b/contrib/models/flux1-lite-8b/test/integration/test_model.py new file mode 100644 index 00000000..5b27978d --- /dev/null +++ 
b/contrib/models/flux1-lite-8b/test/integration/test_model.py @@ -0,0 +1,223 @@ +""" +Integration tests for FLUX.1-lite-8B-alpha on Neuron. + +FLUX.1-lite-8B-alpha is architecturally identical to FLUX.1-dev with fewer +double-stream blocks (8 vs 19). It runs natively on NxDI's first-party +FLUX.1 implementation with no code modifications. + +Tests: +1. test_smoke_pipeline_loads: Pipeline loads without errors +2. test_generation_produces_image: Generates a 1024x1024 image +3. test_warm_generation_time: Warm generation < 15s (25 steps, no CFG) + +Requirements: + - trn2.3xlarge with LNC=2 (4 logical cores) + - Neuron SDK 2.29+ + - diffusers >= 0.37.1, transformers, sentencepiece, protobuf + - Model downloaded to /shared/flux1-lite-8b/ or $FLUX_LITE_MODEL_PATH + +Run: + pytest test_model.py -v + # Or standalone: + python test_model.py +""" + +import gc +import os +import sys +import time + +import numpy as np +import pytest +import torch + +MODEL_PATH = os.environ.get("FLUX_LITE_MODEL_PATH", "/shared/flux1-lite-8b/") +COMPILE_DIR = os.environ.get("FLUX_LITE_COMPILE_DIR", "/tmp/flux_lite_test/") +TP_DEGREE = int(os.environ.get("FLUX_LITE_TP_DEGREE", "4")) +HEIGHT = 1024 +WIDTH = 1024 +NUM_STEPS = 25 +GUIDANCE_SCALE = 3.5 +PROMPT = "A cat holding a sign that says hello world" + + +@pytest.fixture(scope="module") +def neuron_app(): + """Create, compile, and load FLUX.1-lite using NxDI's FLUX.1 application.""" + from neuronx_distributed_inference.models.diffusers.flux.application import ( + NeuronFluxApplication, + create_flux_config, + get_flux_parallelism_config, + ) + from neuronx_distributed_inference.utils.random import set_random_seed + + set_random_seed(0) + + world_size = get_flux_parallelism_config(TP_DEGREE) + dtype = torch.bfloat16 + + clip_config, t5_config, backbone_config, decoder_config = create_flux_config( + MODEL_PATH, + world_size, + TP_DEGREE, + dtype, + HEIGHT, + WIDTH, + ) + + app = NeuronFluxApplication( + model_path=MODEL_PATH, + 
text_encoder_config=clip_config, + text_encoder2_config=t5_config, + backbone_config=backbone_config, + decoder_config=decoder_config, + height=HEIGHT, + width=WIDTH, + ) + + app.compile(COMPILE_DIR) + app.load(COMPILE_DIR) + + # Warmup + app( + PROMPT, + height=HEIGHT, + width=WIDTH, + guidance_scale=GUIDANCE_SCALE, + num_inference_steps=NUM_STEPS, + ) + + yield app + + del app + gc.collect() + + +def test_smoke_pipeline_loads(neuron_app): + """Pipeline loads without errors and has required components.""" + assert neuron_app is not None + assert neuron_app.pipe is not None + assert neuron_app.pipe.transformer is not None + assert neuron_app.pipe.text_encoder is not None + assert neuron_app.pipe.text_encoder_2 is not None + assert neuron_app.pipe.vae is not None + + +def test_generation_produces_image(neuron_app): + """Generates an image at the expected resolution.""" + result = neuron_app( + PROMPT, + height=HEIGHT, + width=WIDTH, + guidance_scale=GUIDANCE_SCALE, + num_inference_steps=NUM_STEPS, + ) + + assert result is not None + assert hasattr(result, "images") + assert len(result.images) == 1 + + image = result.images[0] + assert image.size == (WIDTH, HEIGHT), ( + f"Expected ({WIDTH}, {HEIGHT}), got {image.size}" + ) + + # Verify the image has reasonable pixel values (not blank/noise) + img_array = np.array(image) + assert img_array.shape == (HEIGHT, WIDTH, 3) + assert img_array.std() > 10, "Image appears blank or uniform" + + # Save for inspection + os.makedirs(os.path.join(COMPILE_DIR, "test_outputs"), exist_ok=True) + image.save(os.path.join(COMPILE_DIR, "test_outputs", "test_generation.png")) + print(f"Image saved, pixel std={img_array.std():.1f}") + + +def test_warm_generation_time(neuron_app): + """Warm generation should complete in reasonable time.""" + t0 = time.time() + neuron_app( + PROMPT, + height=HEIGHT, + width=WIDTH, + guidance_scale=GUIDANCE_SCALE, + num_inference_steps=NUM_STEPS, + ) + elapsed = time.time() - t0 + print(f"Warm generation time: 
{elapsed:.2f}s")
+
+    # FLUX.1-lite with 25 steps, no CFG, TP=4: expect ~6s, allow up to 15s
+    assert elapsed < 15, f"Generation took {elapsed:.2f}s, expected < 15s"
+
+
+# Standalone runner
+if __name__ == "__main__":
+    from neuronx_distributed_inference.models.diffusers.flux.application import (
+        NeuronFluxApplication,
+        create_flux_config,
+        get_flux_parallelism_config,
+    )
+    from neuronx_distributed_inference.utils.random import set_random_seed
+
+    set_random_seed(0)
+
+    print("=" * 60)
+    print("FLUX.1-lite-8B-alpha Integration Tests")
+    print("=" * 60)
+
+    world_size = get_flux_parallelism_config(TP_DEGREE)
+
+    clip_config, t5_config, backbone_config, decoder_config = create_flux_config(
+        MODEL_PATH,
+        world_size,
+        TP_DEGREE,
+        torch.bfloat16,
+        HEIGHT,
+        WIDTH,
+    )
+
+    print(f"\nModel config:")
+    print(f"  num_layers (double blocks): {backbone_config.num_layers}")
+    print(f"  num_single_layers: {backbone_config.num_single_layers}")
+    print(f"  TP degree: {TP_DEGREE}")
+
+    app = NeuronFluxApplication(
+        model_path=MODEL_PATH,
+        text_encoder_config=clip_config,
+        text_encoder2_config=t5_config,
+        backbone_config=backbone_config,
+        decoder_config=decoder_config,
+        height=HEIGHT,
+        width=WIDTH,
+    )
+
+    # Step labels count all six stages (three setup stages, then three tests).
+    print("\n[1/6] Compiling...")
+    t0 = time.time()
+    app.compile(COMPILE_DIR)
+    print(f"  Compilation: {time.time() - t0:.1f}s")
+
+    print("\n[2/6] Loading...")
+    app.load(COMPILE_DIR)
+
+    print("\n[3/6] Warmup...")
+    app(
+        PROMPT,
+        height=HEIGHT,
+        width=WIDTH,
+        guidance_scale=GUIDANCE_SCALE,
+        num_inference_steps=NUM_STEPS,
+    )
+
+    print("\n[4/6] test_smoke_pipeline_loads")
+    test_smoke_pipeline_loads(app)
+    print("  PASSED")
+
+    print("\n[5/6] test_generation_produces_image")
+    test_generation_produces_image(app)
+    print("  PASSED")
+
+    print("\n[6/6] test_warm_generation_time")
+    test_warm_generation_time(app)
+    print("  PASSED")
+
+    print("\nAll tests passed!")