
[ONNX] Add Chronos-2 covariate export validation #508

Open
Jameswlepage wants to merge 4 commits into amazon-science:main from Jameswlepage:chronos2-onnx-covariates

@Jameswlepage

Summary

This builds on and supersedes the initial Chronos-2 ONNX export work in #359.

It updates the Chronos-2 ONNX export path so users can export a real Chronos-2 model with future covariates, repair the exported ONNX graph, and validate ONNX Runtime output against PyTorch.

Changes

  • Export amazon/chronos-2 by default with future covariates enabled.
  • Call torch.onnx.export with dynamo=False to select the legacy TorchScript exporter explicitly, because newer PyTorch releases default to the dynamo-based exporter.
  • Keep batch size dynamic while documenting fixed context and forecast lengths.
  • Run the ONNX graph repair pass automatically after raw export.
  • Fix fix_onnx_model.py so it corrects Gather index dtype issues without incorrectly advertising a dynamic horizon.
  • Add a standalone PyTorch-vs-ONNX parity harness for grouped batches, future covariates, and missing values.
  • Update quantized-model validation to infer the exported fixed shapes from the ONNX graph.
  • Add ONNX export documentation with supported inputs, outputs, and limitations.
  • Adjust Chronos patching and missing-value normalization paths so the Chronos-2 model exports cleanly.
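As a rough illustration of the exporter settings described above, the sketch below builds keyword arguments for torch.onnx.export that pin the legacy exporter and mark only the batch axis as dynamic. The tensor names come from the ONNX contract in this description; the helper names, the opset version, and the way the model and example inputs are passed are assumptions, not the actual contents of export_chronos2_to_onnx.py. The optional num_output_patches input is left out.

```python
# Hypothetical helper names; tensor names follow the contract in this PR description.
INPUT_NAMES = ["context", "group_ids", "attention_mask", "future_covariates"]
OUTPUT_NAMES = ["quantile_preds"]


def build_export_kwargs(opset_version=17):
    """Return keyword arguments for torch.onnx.export.

    opset_version=17 is an assumed value, not one stated in the PR.
    """
    return {
        "input_names": list(INPUT_NAMES),
        "output_names": list(OUTPUT_NAMES),
        # Only axis 0 (batch) is dynamic; all other dimensions are fixed by tracing.
        "dynamic_axes": {name: {0: "batch_size"} for name in INPUT_NAMES + OUTPUT_NAMES},
        "opset_version": opset_version,
        # Select the legacy TorchScript exporter explicitly.
        "dynamo": False,
    }


def export(model, example_inputs, path):
    """Export `model` to `path`; `example_inputs` is a tuple of example tensors."""
    import torch  # imported lazily so the helper above works without PyTorch

    torch.onnx.export(model, example_inputs, path, **build_export_kwargs())
```

The dynamo keyword is only accepted by PyTorch releases that ship the dynamo-based exporter option, so older installations would need it removed.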

ONNX Contract

The default export has the following inputs, plus a single output, quantile_preds:

  • context: float32[batch_size, 512]
  • group_ids: int64[batch_size]
  • attention_mask: float32[batch_size, 512]
  • future_covariates: float32[batch_size, 64]
  • num_output_patches: int64[] when exposed by the exporter
  • quantile_preds: float32[batch_size, 21, 64]

Batch size is dynamic. Context length (512), future-covariate length (64), and prediction length (64) are fixed by the traced export.
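A caller could check feeds against this contract before invoking ONNX Runtime. The following is a minimal sketch that works on plain (dtype, shape) descriptions rather than real arrays; the CONTRACT table mirrors the default shapes listed above, and check_inputs is a hypothetical helper, not part of the PR. It skips the optional num_output_patches input.

```python
# Default input contract; "batch_size" marks the one dynamic axis.
CONTRACT = {
    "context": ("float32", ("batch_size", 512)),
    "group_ids": ("int64", ("batch_size",)),
    "attention_mask": ("float32", ("batch_size", 512)),
    "future_covariates": ("float32", ("batch_size", 64)),
}


def check_inputs(feeds):
    """Validate feeds (name -> (dtype string, shape tuple)) against CONTRACT.

    Returns the common batch size, or raises ValueError on the first mismatch.
    """
    batch = None
    for name, (dtype, shape) in CONTRACT.items():
        if name not in feeds:
            raise ValueError(f"missing input: {name}")
        got_dtype, got_shape = feeds[name]
        if got_dtype != dtype:
            raise ValueError(f"{name}: expected dtype {dtype}, got {got_dtype}")
        if len(got_shape) != len(shape):
            raise ValueError(f"{name}: expected rank {len(shape)}, got {len(got_shape)}")
        for want, got in zip(shape, got_shape):
            if want == "batch_size":
                # All inputs must agree on the dynamic batch dimension.
                if batch is None:
                    batch = got
                elif got != batch:
                    raise ValueError(f"{name}: batch size {got} != {batch}")
            elif want != got:
                raise ValueError(f"{name}: expected fixed dimension {want}, got {got}")
    return batch
```

With real NumPy arrays, the same check could be fed from `str(arr.dtype)` and `arr.shape`.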

Validation

Local checks run:

  • python -m py_compile scripts/onnx/export_chronos2_to_onnx.py scripts/onnx/fix_onnx_model.py scripts/onnx/quantize_chronos2.py scripts/onnx/validate_chronos2_onnx.py src/chronos/chronos_bolt.py
  • python scripts/onnx/export_chronos2_to_onnx.py --model_id test/dummy-chronos2-model --output_dir ... --device cpu --validate
  • python scripts/onnx/validate_chronos2_onnx.py --model_id test/dummy-chronos2-model --onnx_path .../model.onnx --context_length 512 --num_output_patches 4
  • python scripts/onnx/quantize_chronos2.py --input .../model.onnx --output .../model_quantized.onnx --mode dynamic --validate

A real amazon/chronos-2 ONNX covariate export and parity reports are published here:

https://huggingface.co/TSFM-ai/chronos-2-onnx

The real export parity report covers batch sizes 1 through 5, shared and distinct group_ids, several future-covariate patterns, missing context values, and missing future-covariate values.
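To make the shape of such a parity sweep concrete, here is a small sketch that enumerates cases over batch sizes 1 through 5, shared versus distinct group_ids, and where missing values are injected, together with a helper for the largest absolute difference between two nested output lists. All names are hypothetical and the grid is a simplification; it is not the harness in validate_chronos2_onnx.py, and it sets no tolerance because the description does not state one.

```python
from itertools import product

BATCH_SIZES = range(1, 6)  # batch sizes 1 through 5
GROUP_MODES = ("shared", "distinct")
MISSING = ("none", "context", "future_covariates")


def make_group_ids(batch_size, mode):
    """Build group_ids for a batch: one shared group, or one group per row."""
    if mode == "shared":
        return [0] * batch_size
    if mode == "distinct":
        return list(range(batch_size))
    raise ValueError(f"unknown group mode: {mode}")


def parity_cases():
    """Yield one case description per combination in the grid."""
    for batch_size, mode, missing in product(BATCH_SIZES, GROUP_MODES, MISSING):
        yield {
            "batch_size": batch_size,
            "group_ids": make_group_ids(batch_size, mode),
            "missing": missing,
        }


def max_abs_diff(a, b):
    """Largest absolute elementwise difference between two equally shaped nested lists."""
    if isinstance(a, (list, tuple)):
        if len(a) != len(b):
            raise ValueError("shape mismatch")
        return max((max_abs_diff(x, y) for x, y in zip(a, b)), default=0.0)
    return abs(a - b)
```

Each case would be run through both the PyTorch model and the ONNX Runtime session, with max_abs_diff applied to the two quantile_preds outputs.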

@Jameswlepage

Author

@kashif I opened this as a follow-up to #359 with the Chronos-2 covariate export fixes, ONNX graph repair flow, and PyTorch-vs-ONNX parity validation. Would appreciate your review when you have a chance.
