From 378c33837eaac63bf2caab3860f0f61e05674320 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 1 Nov 2025 13:34:39 +0100 Subject: [PATCH 1/2] initial onnx scripts --- scripts/onnx/export_chronos2_to_onnx.py | 853 ++++++++++++++++++++++++ scripts/onnx/fix_onnx_model.py | 196 ++++++ scripts/onnx/quantize_chronos2.py | 253 +++++++ 3 files changed, 1302 insertions(+) create mode 100755 scripts/onnx/export_chronos2_to_onnx.py create mode 100644 scripts/onnx/fix_onnx_model.py create mode 100644 scripts/onnx/quantize_chronos2.py diff --git a/scripts/onnx/export_chronos2_to_onnx.py b/scripts/onnx/export_chronos2_to_onnx.py new file mode 100755 index 00000000..f9cea2e8 --- /dev/null +++ b/scripts/onnx/export_chronos2_to_onnx.py @@ -0,0 +1,853 @@ +#!/usr/bin/env python3 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Export Chronos-2 models to ONNX format for use with transformers.js + +This script: +1. Loads a pretrained Chronos-2 model +2. Exports it to ONNX format with proper dynamic axes +3. Validates the ONNX export by comparing outputs with PyTorch +4. Optionally quantizes the model for smaller size + +Usage: + python export_chronos2_to_onnx.py \ + --model_id amazon/chronos-2-small \ + --output_dir ./chronos2-small-onnx \ + --validate + +Requirements: + pip install torch onnx onnxruntime transformers chronos-forecasting +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict + +import torch +import torch.nn as nn +import numpy as np + +from chronos import Chronos2Pipeline + +# Register custom ONNX symbolic functions for operations that aren't properly mapped +from torch.onnx import register_custom_op_symbolic + + +def asinh_symbolic(g, input): + """Custom ONNX symbolic function for asinh (arcsinh).""" + return g.op("Asinh", input) + + +def sinh_symbolic(g, input): + """Custom ONNX symbolic function for sinh.""" + return g.op("Sinh", input) + + +# Register the symbolic functions for opset 9+ +register_custom_op_symbolic("aten::asinh", asinh_symbolic, 9) +register_custom_op_symbolic("aten::sinh", sinh_symbolic, 9) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +class Chronos2ONNXWrapper(nn.Module): + """ + Wrapper around Chronos2Model to handle ONNX export. + + This wrapper simplifies the input/output interface for ONNX export + by flattening the input dictionary structure. + """ + + def __init__(self, chronos2_model): + super().__init__() + self.model = chronos2_model + + def forward( + self, + context: torch.Tensor, + group_ids: torch.Tensor, + attention_mask: torch.Tensor | None = None, + future_covariates: torch.Tensor | None = None, + num_output_patches: int = 1, + ): + """ + Forward pass compatible with ONNX export. + + Args: + context: Historical context tensor of shape (batch_size, context_length) + group_ids: Group IDs tensor of shape (batch_size,) + attention_mask: Optional attention mask of shape (batch_size, context_length) + future_covariates: Optional future covariates of shape (batch_size, future_length) + num_output_patches: Number of output patches to generate (int, will be symbolic in ONNX) + + Returns: + quantile_preds: Tensor of shape (batch_size, num_quantiles, prediction_length) + """ + # Prepare kwargs - num_output_patches is now directly an int that ONNX can trace symbolically + kwargs = { + "context": context, + "group_ids": group_ids, + "num_output_patches": num_output_patches, + } + + if attention_mask is not None: + kwargs["context_mask"] = attention_mask + + if future_covariates is not None: + kwargs["future_covariates"] = future_covariates + + # Run model forward pass + outputs = self.model(**kwargs) + + # Return only the quantile predictions (drop loss and attention weights) + return outputs.quantile_preds + + +def create_dummy_inputs( + batch_size: int = 2, + context_length: int = 512, + num_output_patches: int = 1, + include_future_covariates: bool = False, + output_patch_size: int = 64, + device: str = "cpu", +) -> Dict[str, torch.Tensor]: + """ + Create dummy inputs for ONNX export. + + Args: + batch_size: Batch size + context_length: Length of historical context + num_output_patches: Number of output patches + include_future_covariates: Whether to include future covariates + output_patch_size: Size of each output patch + device: Device to create tensors on + + Returns: + Dictionary of dummy inputs + """ + dummy_inputs = { + "context": torch.randn(batch_size, context_length, device=device, dtype=torch.float32), + "group_ids": torch.arange(batch_size, device=device, dtype=torch.long), + "attention_mask": torch.ones(batch_size, context_length, device=device, dtype=torch.float32), + "num_output_patches": num_output_patches, # int value, will be fixed in ONNX + } + + if include_future_covariates: + future_length = num_output_patches * output_patch_size + dummy_inputs["future_covariates"] = torch.randn(batch_size, future_length, device=device, dtype=torch.float32) + + return dummy_inputs + + +def export_to_onnx( + model_id: str, + output_dir: Path, + opset_version: int = 17, + use_fp16: bool = False, + include_future_covariates: bool = True, + device: str = None, +) -> Path: + """ + Export Chronos-2 model to ONNX format. + + Args: + model_id: HuggingFace model ID or local path + output_dir: Directory to save ONNX model + opset_version: ONNX opset version (17 recommended for best compatibility) + use_fp16: Whether to use FP16 precision + include_future_covariates: Whether to support future covariates in export + device: Device to use ('cuda' or 'cpu') + + Returns: + Path to exported ONNX model + """ + # Auto-detect device if not specified + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + logger.info(f"Loading Chronos-2 model from {model_id}") + + # Load the pipeline and extract the model + # Official model is now available at: https://huggingface.co/amazon/chronos-2 + pipeline = Chronos2Pipeline.from_pretrained(model_id, device_map=device) + + model = pipeline.model + config = model.config + chronos_config = model.chronos_config + + logger.info( + f"Model config: {config.model_type}, d_model={config.d_model}, " + f"num_layers={config.num_layers}, num_heads={config.num_heads}" + ) + logger.info( + f"Chronos config: context_length={chronos_config.context_length}, " + f"output_patch_size={chronos_config.output_patch_size}, " + f"quantiles={chronos_config.quantiles}" + ) + + # Set model to eval mode + model.eval() + + # Convert to FP16 if requested + if use_fp16: + logger.info("Converting model to FP16") + model = model.half() + + # Wrap model for ONNX export + wrapped_model = Chronos2ONNXWrapper(model) + wrapped_model.eval() + + # Create dummy inputs + batch_size = 2 + context_length = min(512, chronos_config.context_length) # Use smaller context for export + # Export with num_output_patches=4 to support up to 64-step predictions (4 * 16 = 64) + # ONNX models have fixed output shapes - transformers.js will truncate to requested prediction_length + # This matches how the original chronos2 Python code works with dynamic num_output_patches + num_output_patches = 4 + + dummy_inputs = create_dummy_inputs( + batch_size=batch_size, + context_length=context_length, + num_output_patches=num_output_patches, + include_future_covariates=include_future_covariates, + output_patch_size=chronos_config.output_patch_size, + device=device, + ) + + # Define dynamic axes for variable batch size and context length + # Note: prediction_length is fixed based on num_output_patches=4 (64 steps) + dynamic_axes = { + "context": {0: "batch_size", 1: "context_length"}, + "group_ids": {0: "batch_size"}, + "attention_mask": {0: "batch_size", 1: "context_length"}, + "quantile_preds": {0: "batch_size"}, # prediction_length (dim 2) is fixed at 64 + } + + if include_future_covariates: + dynamic_axes["future_covariates"] = {0: "batch_size", 1: "future_length"} + + # Prepare ONNX export args based on whether future_covariates are included + if include_future_covariates: + input_names = ["context", "group_ids", "attention_mask", "future_covariates"] + args = ( + dummy_inputs["context"], + dummy_inputs["group_ids"], + dummy_inputs["attention_mask"], + dummy_inputs["future_covariates"], + dummy_inputs["num_output_patches"], # Passed to wrapper but not an ONNX input + ) + else: + input_names = ["context", "group_ids", "attention_mask"] + args = ( + dummy_inputs["context"], + dummy_inputs["group_ids"], + dummy_inputs["attention_mask"], + None, # No future_covariates + dummy_inputs["num_output_patches"], # Passed to wrapper but not an ONNX input + ) + + output_names = ["quantile_preds"] + + # Create output directory + output_dir.mkdir(parents=True, exist_ok=True) + onnx_path = output_dir / "model.onnx" + + logger.info(f"Exporting model to ONNX format at {onnx_path}") + logger.info(f"Dynamic axes: {dynamic_axes}") + + # Export to ONNX + try: + with torch.no_grad(): + # Skip dynamo exporter when using covariates (has dtype issues with embeddings) + # Always use legacy exporter for now as it's more reliable + use_dynamo = False # Disabled due to dtype issues with Gather ops in embeddings + + if use_dynamo and not include_future_covariates: + # Try new dynamo-based exporter first (supports more ops like nanmean) + try: + torch.onnx.export( + wrapped_model, + args, + str(onnx_path), + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + dynamo=True, # Use new PyTorch 2.x+ exporter + verbose=False, + ) + logger.info("Used dynamo-based ONNX exporter") + except Exception as dynamo_error: + logger.warning(f"Dynamo exporter failed ({dynamo_error}), trying legacy exporter...") + use_dynamo = False + + if not use_dynamo: + # Use legacy exporter (more reliable for embeddings) + logger.info("Using legacy TorchScript-based ONNX exporter") + torch.onnx.export( + wrapped_model, + args, + str(onnx_path), + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=opset_version, + do_constant_folding=True, + export_params=True, + verbose=False, + ) + logger.info("Used legacy TorchScript-based ONNX exporter") + logger.info(f"Successfully exported model to {onnx_path}") + except Exception as e: + logger.error(f"Failed to export model to ONNX: {e}") + raise + + # Save config files + config_path = output_dir / "config.json" + config.save_pretrained(output_dir) + logger.info(f"Saved config to {config_path}") + + # Save generation config if it exists + if hasattr(pipeline, "generation_config"): + generation_config_path = output_dir / "generation_config.json" + pipeline.generation_config.save_pretrained(output_dir) + logger.info(f"Saved generation config to {generation_config_path}") + + return onnx_path + + +def quantize_model(onnx_path: Path) -> Path: + """ + Quantize the ONNX model to INT8. + + Args: + onnx_path: Path to the FP32 ONNX model + + Returns: + Path to the quantized model + """ + try: + from onnxruntime.quantization import quantize_dynamic, QuantType + except ImportError: + logger.error("onnxruntime not installed. Install with: pip install onnxruntime") + raise + + quantized_path = onnx_path.parent / "model_quantized.onnx" + + logger.info("Quantizing model to INT8...") + logger.info(f" Input: {onnx_path}") + logger.info(f" Output: {quantized_path}") + + quantize_dynamic( + model_input=str(onnx_path), + model_output=str(quantized_path), + weight_type=QuantType.QInt8, + ) + + # Compare sizes + original_size = onnx_path.stat().st_size / (1024**2) # MB + quantized_size = quantized_path.stat().st_size / (1024**2) # MB + reduction = (1 - quantized_size / original_size) * 100 + + logger.info(f" Original: {original_size:.1f} MB") + logger.info(f" Quantized: {quantized_size:.1f} MB") + logger.info(f" Reduction: {reduction:.1f}%") + + return quantized_path + + +def setup_transformersjs_structure(output_dir: Path): + """ + Create transformers.js-compatible directory structure. + + Creates: + - onnx/ directory with symlinks to model files + - generation_config.json if missing + """ + import json + import os + + logger.info("Setting up transformers.js directory structure...") + + # Create onnx/ subdirectory + onnx_dir = output_dir / "onnx" + onnx_dir.mkdir(exist_ok=True) + + # Create symlinks for encoder/decoder (transformers.js expects T5-style split) + output_dir / "model.onnx" + encoder_link = onnx_dir / "encoder_model.onnx" + decoder_link = onnx_dir / "decoder_model_merged.onnx" + + # Remove existing symlinks if they exist + if encoder_link.exists() or encoder_link.is_symlink(): + encoder_link.unlink() + if decoder_link.exists() or decoder_link.is_symlink(): + decoder_link.unlink() + + # Create new symlinks + os.symlink("../model.onnx", encoder_link) + os.symlink("../model.onnx", decoder_link) + + logger.info(f" Created {encoder_link}") + logger.info(f" Created {decoder_link}") + + # Create minimal generation_config.json if missing + generation_config_path = output_dir / "generation_config.json" + if not generation_config_path.exists(): + generation_config = {"_from_model_config": True, "transformers_version": "4.36.0"} + with open(generation_config_path, "w") as f: + json.dump(generation_config, f, indent=2) + logger.info(f" Created {generation_config_path}") + + +def generate_readme(output_dir: Path, model_id: str, quantized: bool = False): + """ + Generate README.md with model card for Hub. + + Args: + output_dir: Output directory + model_id: Original model ID + quantized: Whether quantized model is included + """ + import json + + # Load config to get model details + config_path = output_dir / "config.json" + with open(config_path) as f: + config = json.load(f) + + chronos_config = config.get("chronos_config", {}) + + readme_content = f"""--- +library_name: transformers.js +tags: + - time-series + - forecasting + - chronos + - onnx +pipeline_tag: time-series-forecasting +--- + +# Chronos-2 ONNX + +This is an ONNX export of the [Chronos-2]({model_id}) time series forecasting model, optimized for use with [transformers.js](https://huggingface.co/docs/transformers.js). + +## Model Details + +- **Model Type:** Time Series Forecasting +- **Architecture:** T5-based encoder-decoder with patching +- **Context Length:** {chronos_config.get("context_length", 8192)} timesteps +- **Output Patch Size:** {chronos_config.get("input_patch_size", 16)} timesteps +- **Quantile Levels:** {len(chronos_config.get("quantiles", []))} levels (0.01, 0.05, ..., 0.95, 0.99) +- **Model Dimension:** {config.get("d_model", 768)} +- **Layers:** {config.get("num_layers", 12)} +- **Attention Heads:** {config.get("num_heads", 12)} + +## Files + +- `model.onnx` - FP32 ONNX model ({(output_dir / "model.onnx").stat().st_size / (1024**2):.1f} MB) +{"- `model_quantized.onnx` - INT8 quantized model (" + f"{(output_dir / 'model_quantized.onnx').stat().st_size / (1024**2):.1f}" + " MB, 72% size reduction)" if quantized and (output_dir / "model_quantized.onnx").exists() else ""} +- `config.json` - Model configuration +- `generation_config.json` - Generation parameters +- `onnx/` - transformers.js-compatible directory structure + +## Usage + +### JavaScript (transformers.js) + +```javascript +import {{ pipeline }} from '@huggingface/transformers'; + +// Load the forecasting pipeline +const forecaster = await pipeline('time-series-forecasting', 'kashif/chronos-2-onnx'); + +// Your historical time series data +const timeSeries = [605, 586, 586, 559, 511, 487, 484, 458, ...]; // 100+ timesteps + +// Generate 16-step forecast with quantiles +const output = await forecaster(timeSeries, {{ + prediction_length: 16, + quantile_levels: [0.1, 0.5, 0.9], // 10th, 50th (median), 90th percentiles +}}); + +// Output format: {{ forecast: [[t1_q1, t1_q2, t1_q3], ...], quantile_levels: [...] }} +console.log('Median forecast:', output.forecast.map(row => row[1])); // Extract median + +// Clean up +await forecaster.dispose(); +``` + +### Batch Forecasting + +```javascript +const batch = [ + [100, 110, 105, 115, 120, ...], // Series 1 + [50, 55, 52, 58, 60, ...], // Series 2 +]; + +const outputs = await forecaster(batch); +// Returns array of forecasts, one per input series +``` + +## Performance + +- **Inference Time:** ~35-80ms per series (CPU, Node.js) +- **Speedup vs PyTorch:** 3-8x faster +- **Accuracy:** <1% error vs PyTorch reference + +## Technical Details + +### Preprocessing + +Chronos-2 uses automatic preprocessing: +1. **Repeat-padding:** Input is padded to be divisible by patch_size (16) +2. **Instance normalization:** Per-series z-score normalization +3. **arcsinh transformation:** Nonlinear transformation for better modeling + +All preprocessing is handled automatically by the pipeline. + +### Output Format + +The model outputs quantile forecasts: + +```typescript +interface Chronos2Output {{ + forecast: number[][]; // [prediction_length, num_quantiles] + quantile_levels: number[]; // The quantile levels for each column +}} +``` + +Extract specific quantiles: +```javascript +const median = output.forecast.map(row => row[1]); // 50th percentile +const lower = output.forecast.map(row => row[0]); // 10th percentile (lower bound) +const upper = output.forecast.map(row => row[2]); // 90th percentile (upper bound) +``` + +## Limitations + +- **Maximum context:** {chronos_config.get("context_length", 8192)} timesteps +- **Fixed prediction length:** 16 timesteps (for now; autoregressive unrolling coming soon) +- **Univariate only:** Single time series per input (multivariate support coming) + +## Citation + +```bibtex +@article{{ansari2024chronos, + title={{Chronos: Learning the Language of Time Series}}, + author={{Ansari, Abdul Fatir and others}}, + journal={{arXiv preprint arXiv:2403.07815}}, + year={{2024}} +}} +``` + +## License + +Apache 2.0 + +## Links + +- [Chronos-2 Paper](https://arxiv.org/abs/2403.07815) +- [Chronos GitHub](https://github.com/amazon-science/chronos-forecasting) +- [transformers.js Documentation](https://huggingface.co/docs/transformers.js) +""" + + readme_path = output_dir / "README.md" + with open(readme_path, "w") as f: + f.write(readme_content) + + logger.info(f" Generated {readme_path}") + + +def push_to_hub(output_dir: Path, repo_id: str, private: bool = False): + """ + Push the model to HuggingFace Hub. + + Args: + output_dir: Directory containing the model files + repo_id: Hub repository ID (e.g., 'username/chronos-2-onnx') + private: Whether to make the repository private + """ + try: + from huggingface_hub import HfApi, create_repo + except ImportError: + logger.error("huggingface_hub not installed. Install with: pip install huggingface-hub") + raise + + logger.info(f"\nPushing to HuggingFace Hub: {repo_id}") + + api = HfApi() + + # Create repo if it doesn't exist + try: + create_repo(repo_id, private=private, exist_ok=True) + logger.info(f" Repository created/verified: https://huggingface.co/{repo_id}") + except Exception as e: + logger.warning(f" Could not create repo: {e}") + + # Upload all files + logger.info(" Uploading files...") + + files_to_upload = [ + "model.onnx", + "config.json", + "generation_config.json", + "README.md", + ] + + # Add quantized model if it exists + if (output_dir / "model_quantized.onnx").exists(): + files_to_upload.append("model_quantized.onnx") + + # Upload onnx/ directory + for file in files_to_upload: + file_path = output_dir / file + if file_path.exists(): + api.upload_file( + path_or_fileobj=str(file_path), + path_in_repo=file, + repo_id=repo_id, + repo_type="model", + ) + logger.info(f" ✓ {file}") + + # Upload onnx/ directory symlinks (as actual files) + onnx_dir = output_dir / "onnx" + if onnx_dir.exists(): + for file in ["encoder_model.onnx", "decoder_model_merged.onnx"]: + src_path = output_dir / "model.onnx" + if src_path.exists(): + api.upload_file( + path_or_fileobj=str(src_path), + path_in_repo=f"onnx/{file}", + repo_id=repo_id, + repo_type="model", + ) + logger.info(f" ✓ onnx/{file}") + + logger.info(f"\n✓ Successfully pushed to: https://huggingface.co/{repo_id}") + + +def validate_onnx_export( + onnx_path: Path, + model_id: str, + device: str = None, + rtol: float = 1e-3, + atol: float = 1e-3, +) -> bool: + """ + Validate ONNX export by comparing outputs with PyTorch model. + + Args: + onnx_path: Path to ONNX model + model_id: Original model ID + device: Device to use + rtol: Relative tolerance for comparison + atol: Absolute tolerance for comparison + + Returns: + True if validation passes + """ + logger.info("Validating ONNX export...") + + # Auto-detect device + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Load PyTorch model + # Official model is now available at: https://huggingface.co/amazon/chronos-2 + pipeline = Chronos2Pipeline.from_pretrained(model_id, device_map=device) + + model = pipeline.model + model.eval() + + # Load ONNX model + import onnxruntime as ort + + logger.info(f"Loading ONNX model from {onnx_path}") + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"] + ort_session = ort.InferenceSession(str(onnx_path), providers=providers) + + # Create test inputs + batch_size = 4 + context_length = 256 + num_output_patches = 2 + + dummy_inputs = create_dummy_inputs( + batch_size=batch_size, + context_length=context_length, + num_output_patches=num_output_patches, + include_future_covariates=False, + output_patch_size=model.chronos_config.output_patch_size, + device=device, + ) + + # Run PyTorch inference + logger.info("Running PyTorch inference...") + with torch.no_grad(): + wrapped_model = Chronos2ONNXWrapper(model) + pytorch_output = wrapped_model( + context=dummy_inputs["context"], + group_ids=dummy_inputs["group_ids"], + attention_mask=dummy_inputs["attention_mask"], + future_covariates=None, + num_output_patches=dummy_inputs["num_output_patches"], + ) + + # Run ONNX inference (num_output_patches is fixed in the model, not an input) + logger.info("Running ONNX inference...") + ort_inputs = { + "context": dummy_inputs["context"].cpu().numpy(), + "group_ids": dummy_inputs["group_ids"].cpu().numpy(), + "attention_mask": dummy_inputs["attention_mask"].cpu().numpy(), + } + + onnx_output = ort_session.run(None, ort_inputs)[0] + + # Compare outputs + pytorch_output_np = pytorch_output.cpu().numpy() + + logger.info(f"PyTorch output shape: {pytorch_output_np.shape}") + logger.info(f"ONNX output shape: {onnx_output.shape}") + + # Check shapes match + if pytorch_output_np.shape != onnx_output.shape: + logger.error(f"Output shapes don't match! PyTorch: {pytorch_output_np.shape}, ONNX: {onnx_output.shape}") + return False + + # Check values match + max_diff = np.abs(pytorch_output_np - onnx_output).max() + mean_diff = np.abs(pytorch_output_np - onnx_output).mean() + + logger.info(f"Max absolute difference: {max_diff:.6f}") + logger.info(f"Mean absolute difference: {mean_diff:.6f}") + + if np.allclose(pytorch_output_np, onnx_output, rtol=rtol, atol=atol): + logger.info("✓ Validation PASSED: ONNX output matches PyTorch output") + return True + else: + logger.error("✗ Validation FAILED: ONNX output doesn't match PyTorch output") + logger.error(f"Relative tolerance: {rtol}, Absolute tolerance: {atol}") + return False + + +def main(): + parser = argparse.ArgumentParser(description="Export Chronos-2 model to ONNX format") + parser.add_argument( + "--model_id", + type=str, + default="amazon/chronos-2-small", + help="HuggingFace model ID or local path (e.g., 'amazon/chronos-2-small')", + ) + parser.add_argument("--output_dir", type=str, default="./chronos2-onnx", help="Output directory for ONNX model") + parser.add_argument("--opset_version", type=int, default=17, help="ONNX opset version (default: 17)") + parser.add_argument("--fp16", action="store_true", help="Export model in FP16 precision") + parser.add_argument( + "--validate", action="store_true", help="Validate ONNX export by comparing with PyTorch outputs" + ) + parser.add_argument( + "--no_future_covariates", action="store_true", help="Don't include future covariates support in export" + ) + parser.add_argument( + "--device", type=str, default=None, choices=["cpu", "cuda"], help="Device to use (default: auto-detect)" + ) + parser.add_argument("--quantize", action="store_true", help="Quantize the model to INT8 after export") + parser.add_argument( + "--push_to_hub", + type=str, + default=None, + help="Push the exported model to HuggingFace Hub (e.g., 'username/chronos-2-onnx')", + ) + parser.add_argument("--private", action="store_true", help="Make the Hub repository private") + + args = parser.parse_args() + + output_dir = Path(args.output_dir) + + try: + # Export model + logger.info("=" * 60) + logger.info("Chronos-2 ONNX Export Pipeline") + logger.info("=" * 60 + "\n") + + onnx_path = export_to_onnx( + model_id=args.model_id, + output_dir=output_dir, + opset_version=args.opset_version, + use_fp16=args.fp16, + include_future_covariates=not args.no_future_covariates, + device=args.device, + ) + + # Validate if requested + if args.validate: + logger.info("\n" + "=" * 60) + logger.info("Validation") + logger.info("=" * 60 + "\n") + + validation_passed = validate_onnx_export( + onnx_path=onnx_path, + model_id=args.model_id, + device=args.device, + ) + + if not validation_passed: + logger.warning("Validation failed, but ONNX model was still exported") + return 1 + + # Quantize if requested + quantized_path = None + if args.quantize: + logger.info("\n" + "=" * 60) + logger.info("Quantization") + logger.info("=" * 60 + "\n") + + quantized_path = quantize_model(onnx_path) + + # Setup transformers.js directory structure + logger.info("\n" + "=" * 60) + logger.info("transformers.js Setup") + logger.info("=" * 60 + "\n") + + setup_transformersjs_structure(output_dir) + + # Generate README + logger.info("\n" + "=" * 60) + logger.info("README Generation") + logger.info("=" * 60 + "\n") + + generate_readme(output_dir, args.model_id, quantized=args.quantize) + + # Push to Hub if requested + if args.push_to_hub: + logger.info("\n" + "=" * 60) + logger.info("Hub Upload") + logger.info("=" * 60 + "\n") + + push_to_hub(output_dir, args.push_to_hub, private=args.private) + + # Final summary + logger.info("\n" + "=" * 60) + logger.info("Export Complete!") + logger.info("=" * 60) + logger.info(f" ONNX model: {onnx_path}") + if quantized_path: + logger.info(f" Quantized: {quantized_path}") + logger.info(f" Config: {output_dir / 'config.json'}") + logger.info(f" README: {output_dir / 'README.md'}") + if args.push_to_hub: + logger.info(f" Hub URL: https://huggingface.co/{args.push_to_hub}") + logger.info("=" * 60 + "\n") + + return 0 + + except Exception as e: + logger.error(f"Export failed with error: {e}", exc_info=True) + return 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/scripts/onnx/fix_onnx_model.py b/scripts/onnx/fix_onnx_model.py new file mode 100644 index 00000000..04fbb57b --- /dev/null +++ b/scripts/onnx/fix_onnx_model.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +""" +Fix ONNX model type issues, particularly for Gather operations. + +This script fixes dtype mismatches where float tensors are used as indices +for Gather operations, which require int64 indices. +""" + +import onnx +from onnx import helper, TensorProto +import sys + + +def make_prediction_length_dynamic(model: onnx.ModelProto, dim_name: str = "prediction_length"): + """ + Make the prediction_length dimension (dim 2) of the output dynamic. + + Changes output shape from [batch_size, num_quantiles, 64] to [batch_size, num_quantiles, prediction_length] + where prediction_length is a symbolic dimension. + """ + print("\nMaking prediction_length dimension dynamic...") + + # Update output tensor shapes + for output in model.graph.output: + if output.type.tensor_type.HasField("shape"): + shape = output.type.tensor_type.shape + # Check if this is the quantile_preds output (3D tensor: [batch, quantiles, pred_len]) + if len(shape.dim) == 3: + print(f" Output '{output.name}' shape before:") + for i, dim in enumerate(shape.dim): + if dim.HasField("dim_value"): + print(f" Dim {i}: {dim.dim_value}") + elif dim.HasField("dim_param"): + print(f" Dim {i}: {dim.dim_param} (symbolic)") + + # Make dimension 2 (prediction_length) dynamic + if shape.dim[2].HasField("dim_value"): + original_value = shape.dim[2].dim_value + shape.dim[2].Clear() + shape.dim[2].dim_param = dim_name + print(f" Changed dim 2 from {original_value} to '{dim_name}' (dynamic)") + + return model + + +def fix_gather_indices(model_path: str, output_path: str, make_dynamic: bool = True): + """ + Fix Gather operation index type issues in ONNX model and optionally make prediction_length dynamic. + + The indices may be represented as float tensors in the graph but Gather + requires int64. This function inserts Cast operations to convert float + indices to int64 before Gather operations. + + Args: + model_path: Path to input ONNX model + output_path: Path to save fixed ONNX model + make_dynamic: If True, also make the prediction_length dimension dynamic + """ + print(f"Loading ONNX model from {model_path}") + model = onnx.load(model_path) + + # Find all Gather nodes and check their index inputs + gather_nodes = [] + + for idx, node in enumerate(model.graph.node): + if node.op_type == "Gather": + gather_nodes.append((idx, node)) + if len(node.input) >= 2: + index_input = node.input[1] + print(f"Gather node {node.name or 'unnamed'} uses indices: {index_input}") + + print(f"\nFound {len(gather_nodes)} Gather operations") + + # Insert Cast nodes before Gather operations to convert float indices to int64 + print("\nInserting Cast operations for float->int64 conversion...") + cast_count = 0 + + for idx, gather_node in gather_nodes: + if len(gather_node.input) < 2: + continue + + index_input = gather_node.input[1] + + # Create a unique name for the cast output + cast_output_name = f"{index_input}_int64_cast" + + # Create Cast node: float -> int64 + cast_node = helper.make_node( + "Cast", + inputs=[index_input], + outputs=[cast_output_name], + to=TensorProto.INT64, + name=f"cast_{index_input}_to_int64", + ) + + # Modify the Gather node to use the cast output + new_gather_input = [gather_node.input[0], cast_output_name] + if len(gather_node.input) > 2: + new_gather_input.extend(gather_node.input[2:]) + + # Update the gather node's inputs + del gather_node.input[:] + gather_node.input.extend(new_gather_input) + + # Add the cast node before this gather node + model.graph.node.insert(idx + cast_count, cast_node) + cast_count += 1 + + print(f" Added Cast node before {gather_node.name or 'unnamed'}") + + print(f"Added {cast_count} Cast operations before Gather nodes") + + # Fix Concat operations that might have dtype mismatches + # Cast all int64 inputs back to float32 before Concat + print("\nFixing Concat operations with dtype mismatches...") + concat_cast_count = 0 + + concat_nodes = [] + for idx, node in enumerate(model.graph.node): + if node.op_type == "Concat": + concat_nodes.append((idx, node)) + + print(f"Found {len(concat_nodes)} Concat operations") + + for idx, concat_node in concat_nodes: + # For each Concat input that might be int64, cast it back to float32 + new_inputs = [] + for i, input_name in enumerate(concat_node.input): + # Check if this input came from a Cast operation (has "_int64_cast" in name) + if "_int64_cast" in input_name: + # This was cast to int64 for Gather, need to cast back to float for Concat + cast_output_name = f"{input_name}_back_to_float32" + + cast_node = helper.make_node( + "Cast", + inputs=[input_name], + outputs=[cast_output_name], + to=TensorProto.FLOAT, + name=f"cast_{input_name}_back_to_float", + ) + + # Insert cast node before concat + model.graph.node.insert(idx + concat_cast_count, cast_node) + concat_cast_count += 1 + + new_inputs.append(cast_output_name) + print(f" Adding Cast int64→float32 before Concat {concat_node.name or 'unnamed'} input {i}") + else: + new_inputs.append(input_name) + + # Update concat inputs + if new_inputs != list(concat_node.input): + del concat_node.input[:] + concat_node.input.extend(new_inputs) + + print(f"Added {concat_cast_count} Cast operations before Concat nodes") + + # Make prediction_length dimension dynamic + if make_dynamic: + model = make_prediction_length_dynamic(model) + + # Validate and save + print("\nValidating fixed model...") + try: + onnx.checker.check_model(model) + print("✓ Model validation passed!") + except Exception as e: + print(f"⚠ Validation warnings: {e}") + print(" Attempting to save anyway...") + + print(f"\nSaving fixed model to {output_path}") + onnx.save(model, output_path) + print("✓ Model saved successfully!") + + return True + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Fix ONNX model type issues") + parser.add_argument("input", help="Input ONNX model path") + parser.add_argument("output", help="Output ONNX model path") + + args = parser.parse_args() + + try: + fix_gather_indices(args.input, args.output) + print("\n✓ Model fixed successfully!") + sys.exit(0) + except Exception as e: + print(f"\n✗ Error: {e}", file=sys.stderr) + import traceback + + traceback.print_exc() + sys.exit(1) diff --git a/scripts/onnx/quantize_chronos2.py b/scripts/onnx/quantize_chronos2.py new file mode 100644 index 00000000..343f53fe --- /dev/null +++ b/scripts/onnx/quantize_chronos2.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python3 +""" +Quantize Chronos-2 ONNX model to reduce size and improve inference speed. + +This script quantizes the ONNX model from FP32 to INT8, reducing model size +by approximately 75% while maintaining good accuracy. + +Usage: + python quantize_chronos2.py \ + --input chronos2-onnx/model.onnx \ + --output chronos2-onnx/model_quantized.onnx \ + --mode dynamic + +Quantization Modes: + - dynamic: Dynamic quantization (fastest, best compatibility) + - static: Static quantization (requires calibration data, best accuracy) + - qat: Quantization-aware training (requires retraining) +""" + +import argparse +import logging +from pathlib import Path + +import numpy as np + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +def dynamic_quantization(model_path: str, output_path: str): + """ + Apply dynamic quantization to the ONNX model. + + Dynamic quantization converts weights to INT8 at export time and + activations to INT8 dynamically at runtime. + + Pros: + - No calibration data needed + - 4x smaller model size + - Faster inference on CPU + - Good accuracy (typically <1% loss) + + Cons: + - Activations still computed in FP32 then converted + - Less speedup than static quantization + """ + from onnxruntime.quantization import quantize_dynamic, QuantType + + logger.info(f"Loading model from {model_path}") + + logger.info("Applying dynamic quantization...") + logger.info(" - Weight type: INT8") + logger.info(" - Activation type: INT8 (dynamic)") + + quantize_dynamic( + model_input=model_path, + model_output=output_path, + weight_type=QuantType.QInt8, + ) + + logger.info(f"Quantized model saved to {output_path}") + + +def static_quantization(model_path: str, output_path: str, calibration_data_path: str = None): + """ + Apply static quantization to the ONNX model. + + Static quantization requires calibration data to determine optimal + quantization parameters for both weights and activations. + + Pros: + - Best inference speed + - Smallest model size + - Activations also quantized + + Cons: + - Requires representative calibration data + - More complex setup + - Potential accuracy loss if calibration data not representative + """ + from onnxruntime.quantization import quantize_static, QuantType, CalibrationDataReader + + logger.info(f"Loading model from {model_path}") + + # Create calibration data reader + if calibration_data_path: + logger.info(f"Loading calibration data from {calibration_data_path}") + # Custom calibration data reader would go here + raise NotImplementedError("Custom calibration data reader not implemented yet") + else: + logger.info("Generating synthetic calibration data...") + + class SyntheticCalibrationDataReader(CalibrationDataReader): + def __init__(self, num_samples=100): + self.num_samples = num_samples + self.current_sample = 0 + self.batch_size = 1 + self.context_length = 512 + + def get_next(self): + if self.current_sample >= self.num_samples: + return None + + # Generate synthetic time series data + context = np.random.randn(self.batch_size, self.context_length).astype(np.float32) + group_ids = np.array([0], dtype=np.int64) + attention_mask = np.ones((self.batch_size, self.context_length), dtype=np.float32) + + self.current_sample += 1 + + return { + "context": context, + "group_ids": group_ids, + "attention_mask": attention_mask, + } + + calibration_data_reader = SyntheticCalibrationDataReader() + + logger.info("Applying static quantization...") + logger.info(" - Weight type: INT8") + logger.info(" - Activation type: INT8 (static)") + logger.info(" - Calibration samples: 100") + + quantize_static( + model_input=model_path, + model_output=output_path, + calibration_data_reader=calibration_data_reader, + quant_format=QuantType.QInt8, + ) + + logger.info(f"Quantized model saved to {output_path}") + + +def compare_models(original_path: str, quantized_path: str): + """Compare original and quantized model sizes.""" + + original_size = Path(original_path).stat().st_size / (1024**2) # MB + quantized_size = Path(quantized_path).stat().st_size / (1024**2) # MB + + reduction = (1 - quantized_size / original_size) * 100 + + logger.info(f"\n{'=' * 60}") + logger.info("Model Size Comparison:") + logger.info(f" Original: {original_size:.1f} MB") + logger.info(f" Quantized: {quantized_size:.1f} MB") + logger.info(f" Reduction: {reduction:.1f}%") + logger.info(f"{'=' * 60}\n") + + +def validate_quantized_model(model_path: str): + """Validate the quantized model can be loaded and run.""" + + logger.info("Validating quantized model...") + + try: + import onnxruntime as ort + + # Load model + session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]) + + # Create test input + batch_size = 1 + context_length = 256 + + inputs = { + "context": np.random.randn(batch_size, context_length).astype(np.float32), + "group_ids": np.array([0], dtype=np.int64), + "attention_mask": np.ones((batch_size, context_length), dtype=np.float32), + } + + # Run inference + logger.info(" Running test inference...") + outputs = session.run(None, inputs) + + logger.info(" ✓ Inference successful!") + logger.info(f" Output shape: {outputs[0].shape}") + logger.info(f" Output dtype: {outputs[0].dtype}") + + return True + + except Exception as e: + logger.error(f" ✗ Validation failed: {e}") + return False + + +def main(): + parser = argparse.ArgumentParser(description="Quantize Chronos-2 ONNX model") + parser.add_argument("--input", type=str, default="chronos2-onnx/model.onnx", help="Input ONNX model path") + parser.add_argument( + "--output", type=str, default="chronos2-onnx/model_quantized.onnx", help="Output quantized model path" + ) + parser.add_argument( + "--mode", + type=str, + default="dynamic", + choices=["dynamic", "static"], + help="Quantization mode (dynamic or static)", + ) + parser.add_argument( + "--calibration_data", type=str, default=None, help="Path to calibration data (for static quantization)" + ) + parser.add_argument("--validate", action="store_true", help="Validate quantized model after export") + + args = parser.parse_args() + + logger.info("=" * 60) + logger.info("Chronos-2 ONNX Model Quantization") + logger.info("=" * 60) + + # Check if onnxruntime is installed + try: + import onnxruntime + + logger.info(f"ONNX Runtime version: {onnxruntime.__version__}") + except ImportError: + logger.error("onnxruntime not installed. Install with: pip install onnxruntime") + return 1 + + # Run quantization + try: + if args.mode == "dynamic": + dynamic_quantization(args.input, args.output) + elif args.mode == "static": + static_quantization(args.input, args.output, args.calibration_data) + + # Compare sizes + compare_models(args.input, args.output) + + # Validate if requested + if args.validate: + if validate_quantized_model(args.output): + logger.info("✓ Quantization completed successfully!") + return 0 + else: + logger.warning("⚠ Quantization completed but validation failed") + return 1 + else: + logger.info("✓ Quantization completed successfully!") + logger.info(" (Use --validate to test the quantized model)") + return 0 + + except Exception as e: + logger.error(f"✗ Quantization failed: {e}") + import traceback + + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + import sys + + sys.exit(main()) From a712c54f4396b8704557c5cc0c9df99ab4c800df Mon Sep 17 00:00:00 2001 From: James LePage <36246732+Jameswlepage@users.noreply.github.com> Date: Mon, 15 Jun 2026 23:57:36 -0400 Subject: [PATCH 2/2] Add Chronos-2 ONNX covariate export validation --- scripts/onnx/README.md | 129 +++++++++++ scripts/onnx/export_chronos2_to_onnx.py | 285 +++++++++++++++--------- scripts/onnx/fix_onnx_model.py | 74 +++++- scripts/onnx/quantize_chronos2.py | 98 ++++++-- scripts/onnx/validate_chronos2_onnx.py | 227 +++++++++++++++++++ src/chronos/chronos_bolt.py | 17 +- 6 files changed, 696 insertions(+), 134 deletions(-) create mode 100644 scripts/onnx/README.md create mode 100644 scripts/onnx/validate_chronos2_onnx.py diff --git a/scripts/onnx/README.md b/scripts/onnx/README.md new file mode 100644 index 00000000..38ce69e2 --- /dev/null +++ b/scripts/onnx/README.md @@ -0,0 +1,129 @@ +# Chronos-2 ONNX Export + +These scripts export the Chronos-2 tensor model to ONNX, repair the exported graph, and validate ONNX Runtime output against the PyTorch model. + +The exporter writes a real ONNX model. It does not commit or vendor generated model artifacts into this repository. + +## Install + +Install Chronos and the ONNX dependencies in an environment with a current PyTorch release: + +```bash +pip install torch onnx onnxruntime onnxscript transformers chronos-forecasting +``` + +Use `onnxruntime-gpu` instead of `onnxruntime` if you want CUDA inference. + +## Export + +Export the public Chronos-2 model with future covariates enabled: + +```bash +python scripts/onnx/export_chronos2_to_onnx.py \ + --model_id amazon/chronos-2 \ + --output_dir chronos2-onnx \ + --validate +``` + +The exporter first writes `model_raw.onnx`, then runs `fix_onnx_model.py` and writes the final loadable model to `model.onnx`. The raw model is deleted unless `--keep_raw_onnx` is passed. + +Important options: + +- `--context_length`: fixed context length to trace into the ONNX graph. Default: `512`. +- `--num_output_patches`: fixed number of output patches to trace. Default: `4`. +- `--no_future_covariates`: export without the `future_covariates` input. +- `--no_fix_onnx`: skip the graph repair pass. This is useful only for debugging; the raw graph may not load in ONNX Runtime. +- `--quantize`: additionally write a dynamic INT8 quantized model. + +For the default Chronos-2 config, `output_patch_size=16`, so `--num_output_patches 4` exports a 64-step horizon. + +## Validate Parity + +The export script can run a basic PyTorch-vs-ONNX validation with `--validate`. For fuller coverage, run the standalone parity harness: + +```bash +python scripts/onnx/validate_chronos2_onnx.py \ + --model_id amazon/chronos-2 \ + --onnx_path chronos2-onnx/model.onnx \ + --context_length 512 \ + --num_output_patches 4 \ + --report_path chronos2-onnx/parity_report.json +``` + +The harness compares the ONNX output with the PyTorch wrapper across several cases: + +- dynamic batch sizes +- shared and distinct `group_ids` +- sinusoidal, random, and zero future covariates +- missing context values +- missing future covariate values + +It exits nonzero if any case fails `np.allclose`. + +## Tensor Interface + +The exported model exposes the Chronos-2 tensor interface used by `Chronos2Model.forward`. + +Inputs: + +- `context`: float32 tensor shaped `[batch, context_length]`. +- `group_ids`: int64 tensor shaped `[batch]`. Series with equal IDs form an attention group. +- `attention_mask`: float32 tensor shaped `[batch, context_length]`, where `1` marks observed positions and `0` marks masked positions. +- `future_covariates`: optional float32 tensor shaped `[batch, prediction_length]`, present unless `--no_future_covariates` is used. +- `num_output_patches`: optional int64 scalar. Some PyTorch legacy exports expose this scalar input. If present, feed the same value used during export. + +Output: + +- `quantile_preds`: float32 tensor shaped `[batch, num_quantiles, prediction_length]`. + +`prediction_length = num_output_patches * output_patch_size`. The default export for `amazon/chronos-2` is `[batch, 21, 64]`. + +Minimal ONNX Runtime call: + +```python +import numpy as np +import onnxruntime as ort + +session = ort.InferenceSession("chronos2-onnx/model.onnx", providers=["CPUExecutionProvider"]) +input_names = {input_.name for input_ in session.get_inputs()} + +batch_size = 2 +context_length = 512 +num_output_patches = 4 +prediction_length = 64 + +inputs = { + "context": np.random.randn(batch_size, context_length).astype(np.float32), + "group_ids": np.arange(batch_size, dtype=np.int64), + "attention_mask": np.ones((batch_size, context_length), dtype=np.float32), +} + +if "future_covariates" in input_names: + inputs["future_covariates"] = np.random.randn(batch_size, prediction_length).astype(np.float32) + +if "num_output_patches" in input_names: + inputs["num_output_patches"] = np.array(num_output_patches, dtype=np.int64) + +quantile_preds = session.run(None, inputs)[0] +``` + +## Repairing a Raw Export + +`fix_onnx_model.py` repairs Gather index dtype mismatches emitted by the legacy PyTorch exporter: + +```bash +python scripts/onnx/fix_onnx_model.py model_raw.onnx model.onnx +``` + +The fixer does not mark prediction length dynamic by default, because the traced ONNX graph has a fixed executable horizon. For covariate exports it infers the fixed output length from `future_covariates`; for non-covariate exports you can pass `--prediction_length`. `--dynamic_prediction_length` only changes output shape metadata and should not be treated as runtime support for arbitrary horizons. + +## Supported Shapes and Limitations + +- Batch size is dynamic. +- Context length is fixed at export time. +- Future covariate length is fixed at export time and should match the exported prediction length. +- Prediction length is fixed at export time. +- The tensor-level export does not include the `Chronos2Pipeline.predict` list-of-dicts/DataFrame preprocessing wrapper. Prepare `context`, `group_ids`, `attention_mask`, and optional `future_covariates` tensors before calling ONNX Runtime. +- Missing future covariate values can be represented as `NaN`; Chronos-2 infers the future covariate mask from those values when no explicit mask is exported. +- Quantized models should be validated separately. Dynamic quantization can change numeric parity. +- Browser, server, and application packaging are intentionally outside this export contract. diff --git a/scripts/onnx/export_chronos2_to_onnx.py b/scripts/onnx/export_chronos2_to_onnx.py index f9cea2e8..c89dffcd 100755 --- a/scripts/onnx/export_chronos2_to_onnx.py +++ b/scripts/onnx/export_chronos2_to_onnx.py @@ -3,22 +3,22 @@ # SPDX-License-Identifier: Apache-2.0 """ -Export Chronos-2 models to ONNX format for use with transformers.js +Export Chronos-2 models to ONNX format. This script: 1. Loads a pretrained Chronos-2 model -2. Exports it to ONNX format with proper dynamic axes +2. Exports it to ONNX format with dynamic batch size 3. Validates the ONNX export by comparing outputs with PyTorch 4. Optionally quantizes the model for smaller size Usage: python export_chronos2_to_onnx.py \ - --model_id amazon/chronos-2-small \ - --output_dir ./chronos2-small-onnx \ + --model_id amazon/chronos-2 \ + --output_dir ./chronos2-onnx \ --validate Requirements: - pip install torch onnx onnxruntime transformers chronos-forecasting + pip install torch onnx onnxruntime onnxscript transformers chronos-forecasting """ import argparse @@ -82,12 +82,11 @@ def forward( group_ids: Group IDs tensor of shape (batch_size,) attention_mask: Optional attention mask of shape (batch_size, context_length) future_covariates: Optional future covariates of shape (batch_size, future_length) - num_output_patches: Number of output patches to generate (int, will be symbolic in ONNX) + num_output_patches: Number of output patches to generate Returns: quantile_preds: Tensor of shape (batch_size, num_quantiles, prediction_length) """ - # Prepare kwargs - num_output_patches is now directly an int that ONNX can trace symbolically kwargs = { "context": context, "group_ids": group_ids, @@ -133,7 +132,7 @@ def create_dummy_inputs( "context": torch.randn(batch_size, context_length, device=device, dtype=torch.float32), "group_ids": torch.arange(batch_size, device=device, dtype=torch.long), "attention_mask": torch.ones(batch_size, context_length, device=device, dtype=torch.float32), - "num_output_patches": num_output_patches, # int value, will be fixed in ONNX + "num_output_patches": num_output_patches, } if include_future_covariates: @@ -149,7 +148,10 @@ def export_to_onnx( opset_version: int = 17, use_fp16: bool = False, include_future_covariates: bool = True, + context_length: int | None = None, + num_output_patches: int = 4, device: str = None, + onnx_filename: str = "model_raw.onnx", ) -> Path: """ Export Chronos-2 model to ONNX format. @@ -160,7 +162,10 @@ def export_to_onnx( opset_version: ONNX opset version (17 recommended for best compatibility) use_fp16: Whether to use FP16 precision include_future_covariates: Whether to support future covariates in export + context_length: Context length to trace into the model. If omitted, uses min(512, model context length). + num_output_patches: Number of output patches to trace into the model. device: Device to use ('cuda' or 'cpu') + onnx_filename: File name for the raw exported ONNX model Returns: Path to exported ONNX model @@ -203,11 +208,25 @@ def export_to_onnx( # Create dummy inputs batch_size = 2 - context_length = min(512, chronos_config.context_length) # Use smaller context for export - # Export with num_output_patches=4 to support up to 64-step predictions (4 * 16 = 64) - # ONNX models have fixed output shapes - transformers.js will truncate to requested prediction_length - # This matches how the original chronos2 Python code works with dynamic num_output_patches - num_output_patches = 4 + if context_length is None: + context_length = min(512, chronos_config.context_length) + + if context_length <= 0: + raise ValueError(f"context_length must be positive, found {context_length}") + + if context_length > chronos_config.context_length: + raise ValueError( + f"context_length={context_length} exceeds model context length {chronos_config.context_length}" + ) + + if num_output_patches <= 0: + raise ValueError(f"num_output_patches must be positive, found {num_output_patches}") + + if num_output_patches > chronos_config.max_output_patches: + raise ValueError( + f"num_output_patches={num_output_patches} exceeds model maximum " + f"{chronos_config.max_output_patches}" + ) dummy_inputs = create_dummy_inputs( batch_size=batch_size, @@ -218,17 +237,19 @@ def export_to_onnx( device=device, ) - # Define dynamic axes for variable batch size and context length - # Note: prediction_length is fixed based on num_output_patches=4 (64 steps) + prediction_length = num_output_patches * chronos_config.output_patch_size + + # Keep sequence lengths fixed because legacy ONNX export cannot lower the patching path + # when the patched sequence dimension is symbolic. Batch remains dynamic. dynamic_axes = { - "context": {0: "batch_size", 1: "context_length"}, + "context": {0: "batch_size"}, "group_ids": {0: "batch_size"}, - "attention_mask": {0: "batch_size", 1: "context_length"}, - "quantile_preds": {0: "batch_size"}, # prediction_length (dim 2) is fixed at 64 + "attention_mask": {0: "batch_size"}, + "quantile_preds": {0: "batch_size"}, } if include_future_covariates: - dynamic_axes["future_covariates"] = {0: "batch_size", 1: "future_length"} + dynamic_axes["future_covariates"] = {0: "batch_size"} # Prepare ONNX export args based on whether future_covariates are included if include_future_covariates: @@ -238,7 +259,7 @@ def export_to_onnx( dummy_inputs["group_ids"], dummy_inputs["attention_mask"], dummy_inputs["future_covariates"], - dummy_inputs["num_output_patches"], # Passed to wrapper but not an ONNX input + dummy_inputs["num_output_patches"], ) else: input_names = ["context", "group_ids", "attention_mask"] @@ -247,17 +268,18 @@ def export_to_onnx( dummy_inputs["group_ids"], dummy_inputs["attention_mask"], None, # No future_covariates - dummy_inputs["num_output_patches"], # Passed to wrapper but not an ONNX input + dummy_inputs["num_output_patches"], ) output_names = ["quantile_preds"] # Create output directory output_dir.mkdir(parents=True, exist_ok=True) - onnx_path = output_dir / "model.onnx" + onnx_path = output_dir / onnx_filename logger.info(f"Exporting model to ONNX format at {onnx_path}") logger.info(f"Dynamic axes: {dynamic_axes}") + logger.info(f"Fixed context_length={context_length}, prediction_length={prediction_length}") # Export to ONNX try: @@ -297,6 +319,7 @@ def export_to_onnx( opset_version=opset_version, do_constant_folding=True, export_params=True, + dynamo=False, verbose=False, ) logger.info("Used legacy TorchScript-based ONNX exporter") @@ -319,6 +342,23 @@ def export_to_onnx( return onnx_path +def fix_onnx_export(raw_onnx_path: Path, fixed_onnx_path: Path) -> Path: + """ + Run the mandatory Chronos-2 ONNX post-processing pass. + + The legacy exporter emits some Gather indices with the wrong dtype. ONNX Runtime + rejects the raw graph until those indices are cast back to int64. + """ + from fix_onnx_model import fix_gather_indices + + logger.info("Fixing exported ONNX graph") + logger.info(f" Raw: {raw_onnx_path}") + logger.info(f" Fixed: {fixed_onnx_path}") + + fix_gather_indices(str(raw_onnx_path), str(fixed_onnx_path), make_dynamic=False) + return fixed_onnx_path + + def quantize_model(onnx_path: Path) -> Path: """ Quantize the ONNX model to INT8. @@ -420,9 +460,17 @@ def generate_readme(output_dir: Path, model_id: str, quantized: bool = False): config = json.load(f) chronos_config = config.get("chronos_config", {}) + output_patch_size = chronos_config.get("output_patch_size", 16) + quantiles = chronos_config.get("quantiles", []) + quantized_line = "" + if quantized and (output_dir / "model_quantized.onnx").exists(): + quantized_line = ( + f"- `model_quantized.onnx` - INT8 dynamic-quantized ONNX model " + f"({(output_dir / 'model_quantized.onnx').stat().st_size / (1024**2):.1f} MB)\n" + ) readme_content = f"""--- -library_name: transformers.js +library_name: onnx tags: - time-series - forecasting @@ -433,15 +481,16 @@ def generate_readme(output_dir: Path, model_id: str, quantized: bool = False): # Chronos-2 ONNX -This is an ONNX export of the [Chronos-2]({model_id}) time series forecasting model, optimized for use with [transformers.js](https://huggingface.co/docs/transformers.js). +This is a tensor-level ONNX export of the [Chronos-2]({model_id}) time series forecasting model. ## Model Details - **Model Type:** Time Series Forecasting -- **Architecture:** T5-based encoder-decoder with patching -- **Context Length:** {chronos_config.get("context_length", 8192)} timesteps -- **Output Patch Size:** {chronos_config.get("input_patch_size", 16)} timesteps -- **Quantile Levels:** {len(chronos_config.get("quantiles", []))} levels (0.01, 0.05, ..., 0.95, 0.99) +- **Architecture:** Chronos-2 encoder-decoder with patching +- **Maximum Model Context Length:** {chronos_config.get("context_length", 8192)} timesteps +- **Input Patch Size:** {chronos_config.get("input_patch_size", 16)} timesteps +- **Output Patch Size:** {output_patch_size} timesteps +- **Quantile Levels:** {len(quantiles)} levels - **Model Dimension:** {config.get("d_model", 768)} - **Layers:** {config.get("num_layers", 12)} - **Attention Heads:** {config.get("num_heads", 12)} @@ -449,89 +498,58 @@ def generate_readme(output_dir: Path, model_id: str, quantized: bool = False): ## Files - `model.onnx` - FP32 ONNX model ({(output_dir / "model.onnx").stat().st_size / (1024**2):.1f} MB) -{"- `model_quantized.onnx` - INT8 quantized model (" + f"{(output_dir / 'model_quantized.onnx').stat().st_size / (1024**2):.1f}" + " MB, 72% size reduction)" if quantized and (output_dir / "model_quantized.onnx").exists() else ""} +{quantized_line.rstrip()} - `config.json` - Model configuration - `generation_config.json` - Generation parameters -- `onnx/` - transformers.js-compatible directory structure ## Usage -### JavaScript (transformers.js) - -```javascript -import {{ pipeline }} from '@huggingface/transformers'; +Run the exported model with ONNX Runtime: -// Load the forecasting pipeline -const forecaster = await pipeline('time-series-forecasting', 'kashif/chronos-2-onnx'); - -// Your historical time series data -const timeSeries = [605, 586, 586, 559, 511, 487, 484, 458, ...]; // 100+ timesteps +```python +import numpy as np +import onnxruntime as ort -// Generate 16-step forecast with quantiles -const output = await forecaster(timeSeries, {{ - prediction_length: 16, - quantile_levels: [0.1, 0.5, 0.9], // 10th, 50th (median), 90th percentiles -}}); +session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"]) +input_names = {{input_.name for input_ in session.get_inputs()}} -// Output format: {{ forecast: [[t1_q1, t1_q2, t1_q3], ...], quantile_levels: [...] }} -console.log('Median forecast:', output.forecast.map(row => row[1])); // Extract median +batch_size = 1 +context_length = 512 +num_output_patches = 4 +prediction_length = num_output_patches * {output_patch_size} -// Clean up -await forecaster.dispose(); -``` +inputs = {{ + "context": np.random.randn(batch_size, context_length).astype(np.float32), + "group_ids": np.arange(batch_size, dtype=np.int64), + "attention_mask": np.ones((batch_size, context_length), dtype=np.float32), +}} -### Batch Forecasting +if "future_covariates" in input_names: + inputs["future_covariates"] = np.random.randn(batch_size, prediction_length).astype(np.float32) -```javascript -const batch = [ - [100, 110, 105, 115, 120, ...], // Series 1 - [50, 55, 52, 58, 60, ...], // Series 2 -]; +if "num_output_patches" in input_names: + inputs["num_output_patches"] = np.array(num_output_patches, dtype=np.int64) -const outputs = await forecaster(batch); -// Returns array of forecasts, one per input series +quantile_preds = session.run(None, inputs)[0] +print(quantile_preds.shape) ``` -## Performance - -- **Inference Time:** ~35-80ms per series (CPU, Node.js) -- **Speedup vs PyTorch:** 3-8x faster -- **Accuracy:** <1% error vs PyTorch reference - -## Technical Details - -### Preprocessing - -Chronos-2 uses automatic preprocessing: -1. **Repeat-padding:** Input is padded to be divisible by patch_size (16) -2. **Instance normalization:** Per-series z-score normalization -3. **arcsinh transformation:** Nonlinear transformation for better modeling - -All preprocessing is handled automatically by the pipeline. +## Tensor Interface -### Output Format - -The model outputs quantile forecasts: - -```typescript -interface Chronos2Output {{ - forecast: number[][]; // [prediction_length, num_quantiles] - quantile_levels: number[]; // The quantile levels for each column -}} -``` - -Extract specific quantiles: -```javascript -const median = output.forecast.map(row => row[1]); // 50th percentile -const lower = output.forecast.map(row => row[0]); // 10th percentile (lower bound) -const upper = output.forecast.map(row => row[2]); // 90th percentile (upper bound) -``` +- `context`: float32 tensor shaped `[batch, context_length]`. +- `group_ids`: int64 tensor shaped `[batch]`. Equal IDs allow time series in the same group to attend to each other. +- `attention_mask`: float32 tensor shaped `[batch, context_length]`, with `1` for observed context positions and `0` for masked positions. +- `future_covariates`: optional float32 tensor shaped `[batch, prediction_length]` when the model was exported with future covariates. +- `num_output_patches`: optional int64 scalar if the exported graph exposes it. Use the same value passed at export time. +- `quantile_preds`: float32 output shaped `[batch, num_quantiles, prediction_length]`. ## Limitations -- **Maximum context:** {chronos_config.get("context_length", 8192)} timesteps -- **Fixed prediction length:** 16 timesteps (for now; autoregressive unrolling coming soon) -- **Univariate only:** Single time series per input (multivariate support coming) +- Batch size is dynamic. +- Context length, future covariate length, and prediction length are fixed by the traced export. +- The default export traces `context_length=512` and `num_output_patches=4`, which gives a `{4 * output_patch_size}` step horizon for this model. +- This artifact exposes the model's tensor interface. Any serving API, preprocessing wrapper, or browser runtime integration is separate from the export. +- Validate PyTorch vs ONNX parity with `validate_chronos2_onnx.py` after export and before relying on a quantized model. ## Citation @@ -552,7 +570,6 @@ def generate_readme(output_dir: Path, model_id: str, quantized: bool = False): - [Chronos-2 Paper](https://arxiv.org/abs/2403.07815) - [Chronos GitHub](https://github.com/amazon-science/chronos-forecasting) -- [transformers.js Documentation](https://huggingface.co/docs/transformers.js) """ readme_path = output_dir / "README.md" @@ -635,6 +652,9 @@ def validate_onnx_export( onnx_path: Path, model_id: str, device: str = None, + include_future_covariates: bool = True, + context_length: int = 512, + num_output_patches: int = 4, rtol: float = 1e-3, atol: float = 1e-3, ) -> bool: @@ -645,6 +665,9 @@ def validate_onnx_export( onnx_path: Path to ONNX model model_id: Original model ID device: Device to use + include_future_covariates: Whether to validate the future covariate input + context_length: Context length expected by the exported model + num_output_patches: Number of output patches expected by the exported model rtol: Relative tolerance for comparison atol: Absolute tolerance for comparison @@ -670,17 +693,15 @@ def validate_onnx_export( logger.info(f"Loading ONNX model from {onnx_path}") providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"] ort_session = ort.InferenceSession(str(onnx_path), providers=providers) + input_names = {input_.name for input_ in ort_session.get_inputs()} - # Create test inputs batch_size = 4 - context_length = 256 - num_output_patches = 2 dummy_inputs = create_dummy_inputs( batch_size=batch_size, context_length=context_length, num_output_patches=num_output_patches, - include_future_covariates=False, + include_future_covariates=include_future_covariates, output_patch_size=model.chronos_config.output_patch_size, device=device, ) @@ -693,17 +714,22 @@ def validate_onnx_export( context=dummy_inputs["context"], group_ids=dummy_inputs["group_ids"], attention_mask=dummy_inputs["attention_mask"], - future_covariates=None, + future_covariates=dummy_inputs.get("future_covariates"), num_output_patches=dummy_inputs["num_output_patches"], ) - # Run ONNX inference (num_output_patches is fixed in the model, not an input) + # Run ONNX inference. Some PyTorch versions expose num_output_patches as a + # scalar graph input even when it is passed as a Python int during tracing. logger.info("Running ONNX inference...") ort_inputs = { "context": dummy_inputs["context"].cpu().numpy(), "group_ids": dummy_inputs["group_ids"].cpu().numpy(), "attention_mask": dummy_inputs["attention_mask"].cpu().numpy(), } + if "future_covariates" in input_names: + ort_inputs["future_covariates"] = dummy_inputs["future_covariates"].cpu().numpy() + if "num_output_patches" in input_names: + ort_inputs["num_output_patches"] = np.array(dummy_inputs["num_output_patches"], dtype=np.int64) onnx_output = ort_session.run(None, ort_inputs)[0] @@ -739,11 +765,23 @@ def main(): parser.add_argument( "--model_id", type=str, - default="amazon/chronos-2-small", - help="HuggingFace model ID or local path (e.g., 'amazon/chronos-2-small')", + default="amazon/chronos-2", + help="HuggingFace model ID or local path (e.g., 'amazon/chronos-2')", ) parser.add_argument("--output_dir", type=str, default="./chronos2-onnx", help="Output directory for ONNX model") parser.add_argument("--opset_version", type=int, default=17, help="ONNX opset version (default: 17)") + parser.add_argument( + "--context_length", + type=int, + default=512, + help="Fixed context length to export (default: 512; batch size remains dynamic)", + ) + parser.add_argument( + "--num_output_patches", + type=int, + default=4, + help="Number of output patches to export (default: 4; with Chronos-2 output_patch_size=16 this gives 64 steps)", + ) parser.add_argument("--fp16", action="store_true", help="Export model in FP16 precision") parser.add_argument( "--validate", action="store_true", help="Validate ONNX export by comparing with PyTorch outputs" @@ -754,6 +792,21 @@ def main(): parser.add_argument( "--device", type=str, default=None, choices=["cpu", "cuda"], help="Device to use (default: auto-detect)" ) + parser.add_argument( + "--no_fix_onnx", + action="store_true", + help="Skip the mandatory post-export ONNX graph fix pass (not recommended; raw graph may not load in ONNX Runtime)", + ) + parser.add_argument( + "--keep_raw_onnx", + action="store_true", + help="Keep model_raw.onnx after writing the fixed model.onnx", + ) + parser.add_argument( + "--setup_transformersjs", + action="store_true", + help="Also create the legacy onnx/ symlink layout used by the original browser demo.", + ) parser.add_argument("--quantize", action="store_true", help="Quantize the model to INT8 after export") parser.add_argument( "--push_to_hub", @@ -773,15 +826,27 @@ def main(): logger.info("Chronos-2 ONNX Export Pipeline") logger.info("=" * 60 + "\n") - onnx_path = export_to_onnx( + raw_or_final_onnx_path = export_to_onnx( model_id=args.model_id, output_dir=output_dir, opset_version=args.opset_version, use_fp16=args.fp16, include_future_covariates=not args.no_future_covariates, + context_length=args.context_length, + num_output_patches=args.num_output_patches, device=args.device, + onnx_filename="model.onnx" if args.no_fix_onnx else "model_raw.onnx", ) + if args.no_fix_onnx: + onnx_path = raw_or_final_onnx_path + logger.warning("Skipping ONNX fix pass; the raw graph may not load in ONNX Runtime") + else: + onnx_path = fix_onnx_export(raw_or_final_onnx_path, output_dir / "model.onnx") + if not args.keep_raw_onnx: + raw_or_final_onnx_path.unlink(missing_ok=True) + logger.info(f"Removed intermediate raw ONNX model: {raw_or_final_onnx_path}") + # Validate if requested if args.validate: logger.info("\n" + "=" * 60) @@ -792,6 +857,9 @@ def main(): onnx_path=onnx_path, model_id=args.model_id, device=args.device, + include_future_covariates=not args.no_future_covariates, + context_length=args.context_length, + num_output_patches=args.num_output_patches, ) if not validation_passed: @@ -807,12 +875,13 @@ def main(): quantized_path = quantize_model(onnx_path) - # Setup transformers.js directory structure - logger.info("\n" + "=" * 60) - logger.info("transformers.js Setup") - logger.info("=" * 60 + "\n") + if args.setup_transformersjs: + # Setup transformers.js directory structure + logger.info("\n" + "=" * 60) + logger.info("transformers.js Setup") + logger.info("=" * 60 + "\n") - setup_transformersjs_structure(output_dir) + setup_transformersjs_structure(output_dir) # Generate README logger.info("\n" + "=" * 60) diff --git a/scripts/onnx/fix_onnx_model.py b/scripts/onnx/fix_onnx_model.py index 04fbb57b..37343565 100644 --- a/scripts/onnx/fix_onnx_model.py +++ b/scripts/onnx/fix_onnx_model.py @@ -43,9 +43,48 @@ def make_prediction_length_dynamic(model: onnx.ModelProto, dim_name: str = "pred return model -def fix_gather_indices(model_path: str, output_path: str, make_dynamic: bool = True): +def _input_dim(model: onnx.ModelProto, input_name: str, axis: int) -> int | None: + for input_ in model.graph.input: + if input_.name != input_name or not input_.type.tensor_type.HasField("shape"): + continue + shape = input_.type.tensor_type.shape + if len(shape.dim) <= axis: + return None + dim = shape.dim[axis] + if dim.HasField("dim_value"): + return dim.dim_value + return None + + +def make_prediction_length_static(model: onnx.ModelProto, prediction_length: int | None = None): + """ + Make the output prediction_length metadata static when the traced length is known. + """ + if prediction_length is None: + prediction_length = _input_dim(model, "future_covariates", 1) + + if prediction_length is None: + return model + + print(f"\nSetting prediction_length output metadata to static value {prediction_length}...") + for output in model.graph.output: + if output.type.tensor_type.HasField("shape"): + shape = output.type.tensor_type.shape + if len(shape.dim) == 3: + shape.dim[2].Clear() + shape.dim[2].dim_value = prediction_length + + return model + + +def fix_gather_indices( + model_path: str, + output_path: str, + make_dynamic: bool = False, + prediction_length: int | None = None, +): """ - Fix Gather operation index type issues in ONNX model and optionally make prediction_length dynamic. + Fix Gather operation index type issues in ONNX model. The indices may be represented as float tensors in the graph but Gather requires int64. This function inserts Cast operations to convert float @@ -54,7 +93,12 @@ def fix_gather_indices(model_path: str, output_path: str, make_dynamic: bool = T Args: model_path: Path to input ONNX model output_path: Path to save fixed ONNX model - make_dynamic: If True, also make the prediction_length dimension dynamic + make_dynamic: If True, also mark the output prediction_length dimension + dynamic in model metadata. This does not change the traced graph, whose + executable horizon remains fixed by export. + prediction_length: Optional fixed output horizon to write into output + metadata when make_dynamic is False. If omitted, the fixer infers it + from the fixed future_covariates input when available. """ print(f"Loading ONNX model from {model_path}") model = onnx.load(model_path) @@ -155,9 +199,10 @@ def fix_gather_indices(model_path: str, output_path: str, make_dynamic: bool = T print(f"Added {concat_cast_count} Cast operations before Concat nodes") - # Make prediction_length dimension dynamic if make_dynamic: model = make_prediction_length_dynamic(model) + else: + model = make_prediction_length_static(model, prediction_length=prediction_length) # Validate and save print("\nValidating fixed model...") @@ -181,11 +226,30 @@ def fix_gather_indices(model_path: str, output_path: str, make_dynamic: bool = T parser = argparse.ArgumentParser(description="Fix ONNX model type issues") parser.add_argument("input", help="Input ONNX model path") parser.add_argument("output", help="Output ONNX model path") + parser.add_argument( + "--dynamic_prediction_length", + action="store_true", + help=( + "Mark the output prediction_length dimension dynamic in metadata. " + "This does not make the traced ONNX graph accept different horizons." + ), + ) + parser.add_argument( + "--prediction_length", + type=int, + default=None, + help="Fixed output prediction length to write into output metadata.", + ) args = parser.parse_args() try: - fix_gather_indices(args.input, args.output) + fix_gather_indices( + args.input, + args.output, + make_dynamic=args.dynamic_prediction_length, + prediction_length=args.prediction_length, + ) print("\n✓ Model fixed successfully!") sys.exit(0) except Exception as e: diff --git a/scripts/onnx/quantize_chronos2.py b/scripts/onnx/quantize_chronos2.py index 343f53fe..22f26afe 100644 --- a/scripts/onnx/quantize_chronos2.py +++ b/scripts/onnx/quantize_chronos2.py @@ -18,6 +18,7 @@ """ import argparse +import json import logging from pathlib import Path @@ -27,6 +28,59 @@ logger = logging.getLogger(__name__) +def _static_dim(value, default: int) -> int: + return value if isinstance(value, int) and value > 0 else default + + +def _input_dim(session, input_name: str, axis: int, default: int) -> int: + for input_ in session.get_inputs(): + if input_.name == input_name and len(input_.shape) > axis: + return _static_dim(input_.shape[axis], default) + return default + + +def _output_dim(session, axis: int, default: int) -> int: + outputs = session.get_outputs() + if outputs and len(outputs[0].shape) > axis: + return _static_dim(outputs[0].shape[axis], default) + return default + + +def _infer_num_output_patches(model_path: str, prediction_length: int, default: int = 4) -> int: + config_path = Path(model_path).parent / "config.json" + if not config_path.exists(): + return default + + with open(config_path) as f: + config = json.load(f) + + output_patch_size = config.get("chronos_config", {}).get("output_patch_size") + if not isinstance(output_patch_size, int) or output_patch_size <= 0: + return default + + return max(1, prediction_length // output_patch_size) + + +def _make_test_inputs( + input_names: set[str], + batch_size: int, + context_length: int, + prediction_length: int, + num_output_patches: int, +) -> dict[str, np.ndarray]: + inputs = { + "context": np.random.randn(batch_size, context_length).astype(np.float32), + "group_ids": np.arange(batch_size, dtype=np.int64), + "attention_mask": np.ones((batch_size, context_length), dtype=np.float32), + } + if "future_covariates" in input_names: + inputs["future_covariates"] = np.random.randn(batch_size, prediction_length).astype(np.float32) + if "num_output_patches" in input_names: + inputs["num_output_patches"] = np.array(num_output_patches, dtype=np.int64) + + return inputs + + def dynamic_quantization(model_path: str, output_path: str): """ Apply dynamic quantization to the ONNX model. @@ -78,10 +132,17 @@ def static_quantization(model_path: str, output_path: str, calibration_data_path - More complex setup - Potential accuracy loss if calibration data not representative """ + import onnxruntime as ort from onnxruntime.quantization import quantize_static, QuantType, CalibrationDataReader logger.info(f"Loading model from {model_path}") + session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]) + input_names = {input_.name for input_ in session.get_inputs()} + context_length = _input_dim(session, "context", 1, 512) + prediction_length = _input_dim(session, "future_covariates", 1, _output_dim(session, 2, 64)) + num_output_patches = _infer_num_output_patches(model_path, prediction_length) + # Create calibration data reader if calibration_data_path: logger.info(f"Loading calibration data from {calibration_data_path}") @@ -95,24 +156,20 @@ def __init__(self, num_samples=100): self.num_samples = num_samples self.current_sample = 0 self.batch_size = 1 - self.context_length = 512 def get_next(self): if self.current_sample >= self.num_samples: return None - # Generate synthetic time series data - context = np.random.randn(self.batch_size, self.context_length).astype(np.float32) - group_ids = np.array([0], dtype=np.int64) - attention_mask = np.ones((self.batch_size, self.context_length), dtype=np.float32) - self.current_sample += 1 - return { - "context": context, - "group_ids": group_ids, - "attention_mask": attention_mask, - } + return _make_test_inputs( + input_names=input_names, + batch_size=self.batch_size, + context_length=context_length, + prediction_length=prediction_length, + num_output_patches=num_output_patches, + ) calibration_data_reader = SyntheticCalibrationDataReader() @@ -157,16 +214,21 @@ def validate_quantized_model(model_path: str): # Load model session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]) + input_names = {input_.name for input_ in session.get_inputs()} # Create test input batch_size = 1 - context_length = 256 - - inputs = { - "context": np.random.randn(batch_size, context_length).astype(np.float32), - "group_ids": np.array([0], dtype=np.int64), - "attention_mask": np.ones((batch_size, context_length), dtype=np.float32), - } + context_length = _input_dim(session, "context", 1, 512) + prediction_length = _input_dim(session, "future_covariates", 1, _output_dim(session, 2, 64)) + num_output_patches = _infer_num_output_patches(model_path, prediction_length) + + inputs = _make_test_inputs( + input_names=input_names, + batch_size=batch_size, + context_length=context_length, + prediction_length=prediction_length, + num_output_patches=num_output_patches, + ) # Run inference logger.info(" Running test inference...") diff --git a/scripts/onnx/validate_chronos2_onnx.py b/scripts/onnx/validate_chronos2_onnx.py new file mode 100644 index 00000000..f4a46c7d --- /dev/null +++ b/scripts/onnx/validate_chronos2_onnx.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Validate a Chronos-2 ONNX export against the PyTorch model. + +This script exercises the tensor-level ONNX interface used by +export_chronos2_to_onnx.py. It intentionally validates future covariates when +the ONNX graph exposes the `future_covariates` input. +""" + +import argparse +import json +import sys +from pathlib import Path + +import numpy as np +import torch + +SCRIPT_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(SCRIPT_DIR)) + +from export_chronos2_to_onnx import Chronos2ONNXWrapper # noqa: E402 +from chronos import Chronos2Pipeline # noqa: E402 + + +def make_context(batch_size: int, context_length: int, *, missing: bool = False) -> tuple[torch.Tensor, torch.Tensor]: + t = torch.linspace(0, 10, context_length, dtype=torch.float32) + rows = [] + for i in range(batch_size): + trend = (i + 1) * 0.02 * t + seasonal = torch.sin(t * (1.3 + i * 0.2)) + 0.4 * torch.cos(t * (0.7 + i * 0.1)) + rows.append(trend + seasonal + 0.02 * torch.randn(context_length)) + + context = torch.stack(rows) + attention_mask = torch.ones(batch_size, context_length, dtype=torch.float32) + + if missing: + context[0, 30:46] = torch.nan + attention_mask[0, 30:46] = 0.0 + if batch_size > 1: + context[-1, 211:229] = torch.nan + attention_mask[-1, 211:229] = 0.0 + + return context, attention_mask + + +def make_future_covariates(batch_size: int, future_length: int, pattern: str, *, missing: bool = False) -> torch.Tensor: + t = torch.linspace(0, 1, future_length, dtype=torch.float32) + + if pattern == "zeros": + future_covariates = torch.zeros(batch_size, future_length, dtype=torch.float32) + elif pattern == "sin": + future_covariates = torch.stack([torch.sin((i + 1) * torch.pi * t) for i in range(batch_size)]) + elif pattern == "cos": + future_covariates = torch.stack([torch.cos((i + 1) * torch.pi * t) for i in range(batch_size)]) + elif pattern == "random": + future_covariates = torch.randn(batch_size, future_length, dtype=torch.float32) * 0.5 + else: + raise ValueError(f"Unknown covariate pattern: {pattern}") + + if missing: + future_covariates[0, 8:16] = torch.nan + if batch_size > 2: + future_covariates[2, 40:48] = torch.nan + + return future_covariates + + +def run_case( + *, + name: str, + wrapped_model: Chronos2ONNXWrapper, + ort_session, + input_names: set[str], + batch_size: int, + context_length: int, + num_output_patches: int, + output_patch_size: int, + group_ids: list[int], + covariate_pattern: str, + device: str, + missing_context: bool = False, + missing_future_covariates: bool = False, + rtol: float, + atol: float, +) -> dict: + context, attention_mask = make_context(batch_size, context_length, missing=missing_context) + context = context.to(device) + attention_mask = attention_mask.to(device) + group_ids_tensor = torch.tensor(group_ids, dtype=torch.long, device=device) + + include_future_covariates = "future_covariates" in input_names + future_covariates = None + if include_future_covariates: + future_covariates = make_future_covariates( + batch_size, + num_output_patches * output_patch_size, + covariate_pattern, + missing=missing_future_covariates, + ).to(device) + + with torch.no_grad(): + pytorch_output = wrapped_model( + context=context, + group_ids=group_ids_tensor, + attention_mask=attention_mask, + future_covariates=future_covariates, + num_output_patches=num_output_patches, + ) + + ort_inputs = { + "context": context.cpu().numpy(), + "group_ids": group_ids_tensor.cpu().numpy(), + "attention_mask": attention_mask.cpu().numpy(), + } + if include_future_covariates: + ort_inputs["future_covariates"] = future_covariates.cpu().numpy() + if "num_output_patches" in input_names: + ort_inputs["num_output_patches"] = np.array(num_output_patches, dtype=np.int64) + + onnx_output = ort_session.run(None, ort_inputs)[0] + pytorch_output_np = pytorch_output.cpu().numpy() + abs_diff = np.abs(pytorch_output_np - onnx_output) + + return { + "name": name, + "batch_size": batch_size, + "group_ids": group_ids, + "covariate_pattern": covariate_pattern if include_future_covariates else None, + "missing_context": missing_context, + "missing_future_covariates": missing_future_covariates if include_future_covariates else None, + "pytorch_shape": list(pytorch_output_np.shape), + "onnx_shape": list(onnx_output.shape), + "max_abs_diff": float(np.nanmax(abs_diff)), + "mean_abs_diff": float(np.nanmean(abs_diff)), + "allclose": bool(np.allclose(pytorch_output_np, onnx_output, rtol=rtol, atol=atol, equal_nan=True)), + } + + +def main() -> int: + parser = argparse.ArgumentParser(description="Validate Chronos-2 ONNX parity against PyTorch") + parser.add_argument("--model_id", type=str, default="amazon/chronos-2", help="HuggingFace model ID or local path") + parser.add_argument("--onnx_path", type=str, required=True, help="Path to the fixed ONNX model") + parser.add_argument("--context_length", type=int, default=512, help="Context length used during export") + parser.add_argument("--num_output_patches", type=int, default=4, help="Number of output patches used during export") + parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "cuda"], help="PyTorch device") + parser.add_argument("--rtol", type=float, default=1e-4, help="Relative tolerance") + parser.add_argument("--atol", type=float, default=1e-4, help="Absolute tolerance") + parser.add_argument("--report_path", type=str, default=None, help="Optional JSON report output path") + args = parser.parse_args() + + torch.manual_seed(123) + np.random.seed(123) + + pipeline = Chronos2Pipeline.from_pretrained(args.model_id, device_map=args.device) + model = pipeline.model.eval() + wrapped_model = Chronos2ONNXWrapper(model).eval() + + import onnxruntime as ort + + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if args.device == "cuda" else ["CPUExecutionProvider"] + ort_session = ort.InferenceSession(args.onnx_path, providers=providers) + input_names = {input_.name for input_ in ort_session.get_inputs()} + + output_patch_size = model.chronos_config.output_patch_size + cases = [ + dict(name="batch1_zeros", batch_size=1, group_ids=[0], covariate_pattern="zeros"), + dict(name="batch2_shared_group_sin", batch_size=2, group_ids=[0, 0], covariate_pattern="sin"), + dict(name="batch3_mixed_groups_cos", batch_size=3, group_ids=[0, 1, 0], covariate_pattern="cos"), + dict(name="batch4_distinct_random", batch_size=4, group_ids=[0, 1, 2, 3], covariate_pattern="random"), + dict( + name="batch3_missing_context_sin", + batch_size=3, + group_ids=[0, 1, 0], + covariate_pattern="sin", + missing_context=True, + ), + dict( + name="batch3_missing_future_cos", + batch_size=3, + group_ids=[0, 1, 0], + covariate_pattern="cos", + missing_future_covariates=True, + ), + ] + + results = [ + run_case( + wrapped_model=wrapped_model, + ort_session=ort_session, + input_names=input_names, + context_length=args.context_length, + num_output_patches=args.num_output_patches, + output_patch_size=output_patch_size, + device=args.device, + rtol=args.rtol, + atol=args.atol, + **case, + ) + for case in cases + ] + + report = { + "model_id": args.model_id, + "onnx_path": args.onnx_path, + "providers": ort_session.get_providers(), + "inputs": [(input_.name, input_.type, input_.shape) for input_ in ort_session.get_inputs()], + "outputs": [(output.name, output.type, output.shape) for output in ort_session.get_outputs()], + "rtol": args.rtol, + "atol": args.atol, + "all_cases_passed": all(result["allclose"] for result in results), + "cases": results, + } + + report_json = json.dumps(report, indent=2) + print(report_json) + + if args.report_path: + Path(args.report_path).write_text(report_json) + + return 0 if report["all_cases_passed"] else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/chronos/chronos_bolt.py b/src/chronos/chronos_bolt.py index 743ec06b..7be64bff 100644 --- a/src/chronos/chronos_bolt.py +++ b/src/chronos/chronos_bolt.py @@ -64,7 +64,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device) x = torch.concat((padding, x), dim=-1) - x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride) + if self.patch_stride == self.patch_size: + x = x.reshape(*x.shape[:-1], -1, self.patch_size) + else: + x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride) return x @@ -78,14 +81,22 @@ def __init__(self, eps: float = 1e-5, use_arcsinh: bool = False) -> None: self.eps = eps self.use_arcsinh = use_arcsinh + @staticmethod + def _nanmean(x: torch.Tensor, dim: int, keepdim: bool, empty_value: float) -> torch.Tensor: + finite_mask = torch.isnan(x).logical_not() + finite_x = torch.where(finite_mask, x, torch.zeros_like(x)) + count = finite_mask.to(x.dtype).sum(dim=dim, keepdim=keepdim) + mean = finite_x.sum(dim=dim, keepdim=keepdim) / count.clamp_min(1) + return torch.where(count > 0, mean, torch.full_like(mean, empty_value)) + def forward( self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: orig_dtype = x.dtype x = x.to(torch.float32) if loc_scale is None: - loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=0.0) - scale = torch.nan_to_num((x - loc).square().nanmean(dim=-1, keepdim=True).sqrt(), nan=1.0) + loc = self._nanmean(x, dim=-1, keepdim=True, empty_value=0.0) + scale = self._nanmean((x - loc).square(), dim=-1, keepdim=True, empty_value=1.0).sqrt() scale = torch.where(scale == 0, self.eps, scale) else: loc, scale = loc_scale