Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions src/chronos/chronos2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,10 @@ def _validate_input(
)

def _prepare_patched_context(
self, context: torch.Tensor, context_mask: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
self, context: torch.Tensor, context_mask: torch.Tensor | None = None, return_minmax: bool = False
) -> tuple[
torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor] | None
]:
context_mask = (
context_mask.to(context.dtype)
if context_mask is not None
Expand All @@ -386,6 +388,14 @@ def _prepare_patched_context(
context_mask = context_mask[..., -self.chronos_config.context_length :]

# scaling
context_minmax = None
if return_minmax:
context_min = torch.amin(torch.nan_to_num(context, nan=float("inf")), dim=-1, keepdim=True)
context_min = torch.nan_to_num(context_min, posinf=0.0)
context_max = torch.amax(torch.nan_to_num(context, nan=float("-inf")), dim=-1, keepdim=True)
context_max = torch.nan_to_num(context_max, neginf=0.0)
context_minmax = context_min, context_max

context, loc_scale = self.instance_norm(context)

# scaling is done in 32-bit precision, then the context is moved to model's dtype
Expand Down Expand Up @@ -420,7 +430,7 @@ def _prepare_patched_context(
# concat time encoding, context and mask along the last (feature) dim
patched_context = torch.cat([context_time_enc, patched_context, patched_mask], dim=-1)

return patched_context, attention_mask, loc_scale
return patched_context, attention_mask, loc_scale, context_minmax

def _prepare_patched_future(
self,
Expand Down Expand Up @@ -558,6 +568,7 @@ def encode(
future_target: torch.Tensor | None = None,
future_target_mask: torch.Tensor | None = None,
output_attentions: bool = False,
return_minmax: bool = False,
):
self._validate_input(
context=context,
Expand All @@ -571,8 +582,8 @@ def encode(
)

batch_size = context.shape[0]
patched_context, attention_mask, loc_scale = self._prepare_patched_context(
context=context, context_mask=context_mask
patched_context, attention_mask, loc_scale, context_minmax = self._prepare_patched_context(
context=context, context_mask=context_mask, return_minmax=return_minmax
)
num_context_patches = attention_mask.shape[-1]

Expand Down Expand Up @@ -613,7 +624,7 @@ def encode(
group_ids=group_ids,
output_attentions=output_attentions,
)
return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches
return encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches, context_minmax

def forward(
self,
Expand All @@ -626,6 +637,7 @@ def forward(
future_target: torch.Tensor | None = None,
future_target_mask: torch.Tensor | None = None,
output_attentions: bool = False,
clip_factor: float | None = None,
) -> Chronos2Output:
"""Forward pass of the Chronos2 model.

Expand Down Expand Up @@ -694,7 +706,7 @@ def forward(
- enc_group_self_attn_weights: Group self attention weights, if output_attentions=True
"""
batch_size = context.shape[0]
encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches = self.encode(
encoder_outputs, loc_scale, patched_future_covariates_mask, num_context_patches, context_minmax = self.encode(
context=context,
context_mask=context_mask,
group_ids=group_ids,
Expand All @@ -704,6 +716,7 @@ def forward(
future_target=future_target,
future_target_mask=future_target_mask,
output_attentions=output_attentions,
return_minmax=clip_factor is not None,
)
hidden_states: torch.Tensor = encoder_outputs[0]
assert hidden_states.shape == (batch_size, num_context_patches + 1 + num_output_patches, self.model_dim)
Expand Down Expand Up @@ -741,6 +754,13 @@ def forward(
h=num_output_patches * self.chronos_config.output_patch_size,
)
quantile_preds = self.instance_norm.inverse(quantile_preds, loc_scale)

if clip_factor is not None:
assert context_minmax is not None
clamp_min = context_minmax[0] - clip_factor * loc_scale[1]
clamp_max = context_minmax[1] + clip_factor * loc_scale[1]
quantile_preds = quantile_preds.clamp(min=clamp_min, max=clamp_max)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this PR also take care of the issue of exploding predictions for all-NaN context?

import numpy as np
from chronos import Chronos2Pipeline

model = Chronos2Pipeline.from_pretrained("amazon/chronos-2")
model.predict(np.full((1, 1, 50), float("nan")), prediction_length=5)
# [tensor([[[-8.6131e+06, -1.0066e+07, -1.0386e+07, -1.0977e+07, -1.1329e+07],
#      [-1.2380e+05, -1.1830e+05, -1.5210e+05, -1.5608e+05, -1.5930e+05],
#      [-3.2544e+04, -2.7985e+04, -3.9638e+04, -4.4504e+04, -5.1272e+04],
#      ...

I have a hunch that this is also somehow related to the scale computation.


quantile_preds = rearrange(
quantile_preds,
"b (q h) -> b q h",
Expand Down
21 changes: 19 additions & 2 deletions src/chronos/chronos2/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ def _autoregressive_unroll_for_long_horizon(
unrolled_quantiles: torch.Tensor,
unrolled_sample_weights: torch.Tensor,
num_output_patches: int,
clip_factor: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Get unrolled_quantiles from prediction and append it to the expanded context
prediction_unrolled = interpolate_quantiles(
Expand All @@ -439,6 +440,7 @@ def _autoregressive_unroll_for_long_horizon(
else None,
group_ids=rearrange(group_ids, "b n -> (b n)"),
num_output_patches=num_output_patches,
clip_factor=clip_factor,
)
# Reshape predictions from (batch * n_paths, n_quantiles, length) to (batch, n_paths * n_quantiles, length)
prediction = rearrange(prediction, "(b n) q h -> b (n q) h", n=n_paths)
Expand All @@ -463,6 +465,7 @@ def predict(
context_length: int | None = None,
cross_learning: bool = False,
limit_prediction_length: bool = False,
clip_factor: float | None = None,
**kwargs,
) -> list[torch.Tensor]:
"""
Expand Down Expand Up @@ -647,8 +650,14 @@ def predict(
prediction_length=prediction_length,
max_output_patches=max_output_patches,
target_idx_ranges=batch_target_idx_ranges,
clip_factor=clip_factor,
)
all_predictions.extend(batch_prediction)

# Remove floating point noise around integers
for item in batch_prediction:
item = torch.where(torch.abs(item - item.round()) < 1e-5, item.round(), item)
all_predictions.append(item)

after_batch_callback()

return all_predictions
Expand All @@ -662,6 +671,7 @@ def _predict_batch(
prediction_length: int,
max_output_patches: int,
target_idx_ranges: list[tuple[int, int]],
clip_factor: float | None = None,
) -> list[torch.Tensor]:
context = context.to(device=self.model.device, dtype=torch.float32)
group_ids = group_ids.to(device=self.model.device)
Expand All @@ -682,6 +692,7 @@ def get_num_output_patches(remaining_horizon: int):
group_ids=group_ids,
future_covariates=future_covariates,
num_output_patches=get_num_output_patches(remaining),
clip_factor=clip_factor,
)
predictions.append(prediction)
remaining -= prediction.shape[-1]
Expand All @@ -707,6 +718,7 @@ def get_num_output_patches(remaining_horizon: int):
unrolled_quantiles=unrolled_quantiles_tensor,
unrolled_sample_weights=unrolled_sample_weights,
num_output_patches=get_num_output_patches(remaining),
clip_factor=clip_factor,
)
predictions.append(prediction)
remaining -= prediction.shape[-1]
Expand All @@ -723,6 +735,7 @@ def _predict_step(
group_ids: torch.Tensor,
future_covariates: torch.Tensor | None,
num_output_patches: int,
clip_factor: float | None = None,
) -> torch.Tensor:
kwargs = {}
if future_covariates is not None:
Expand All @@ -741,7 +754,11 @@ def _predict_step(
kwargs["future_covariates"] = future_covariates
with torch.no_grad():
prediction: torch.Tensor = self.model(
context=context, group_ids=group_ids, num_output_patches=num_output_patches, **kwargs
context=context,
group_ids=group_ids,
num_output_patches=num_output_patches,
clip_factor=clip_factor,
**kwargs,
).quantile_preds.to(context)

return prediction
Expand Down
Loading