From 17eeafcaeb7d4dfb135d3ddcfebf3661533e33dd Mon Sep 17 00:00:00 2001 From: brettbj Date: Tue, 9 Jun 2026 19:07:05 -0500 Subject: [PATCH] fix: don't crash the quantile token loss when no kernel is configured aggregate_first/loss_first read cfg.quantile_token_loss.kernel.type unconditionally, so the shipped config/main.yaml (quantile_token_loss without a kernel block) raised ConfigAttributeError on the first training step; a kernel block without 'factor' likewise left self.kernel_factor unset and raised AttributeError later. Resolve kernel_type (default 'linear') and kernel_factor (default 1.0) once in __init__ and reference only those attributes in the loss functions. Co-Authored-By: Claude Fable 5 --- src/cotorra/loss.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/cotorra/loss.py b/src/cotorra/loss.py index 362c108..78b6327 100644 --- a/src/cotorra/loss.py +++ b/src/cotorra/loss.py @@ -74,13 +74,9 @@ def __init__(self, cfg=None, tkzr_cfg=None): self.loss_function = self.loss_functions[self.cfg.quantile_token_loss.loss_type] else: self.loss_function = t.nn.MSELoss - if "kernel" in self.cfg.quantile_token_loss: - if "type" in self.cfg.quantile_token_loss.kernel: - self.kernel_type = self.cfg.quantile_token_loss.kernel.type - else: - self.kernel_type = "linear" - if "factor" in self.cfg.quantile_token_loss.kernel: - self.kernel_factor = self.cfg.quantile_token_loss.kernel.factor + kernel_cfg = self.cfg.quantile_token_loss.get("kernel", {}) + self.kernel_type = kernel_cfg.get("type", "linear") + self.kernel_factor = kernel_cfg.get("factor", 1.0) def quantile_token_loss(self, outputs, labels, **kwargs): @@ -93,7 +89,7 @@ def aggregate_first(cat_logits, cat_labels, i): self.label_to_q[self.label_to_cat == i] ).to(device=cat_logits.device) cat_true = self.label_to_q.to(device=cat_labels.device)[cat_labels] - if self.cfg.quantile_token_loss.kernel.type in self.kernels: + if self.kernel_type in self.kernels: kernel = self.kernels[self.kernel_type] cat_preds = kernel(cat_preds, self.kernel_factor) cat_true = kernel(cat_true, self.kernel_factor) @@ -103,7 +99,7 @@ def loss_first(cat_logits, cat_labels, i): values =( self.label_to_q[self.label_to_cat == i] ).to(device=cat_logits.device) - if self.cfg.quantile_token_loss.kernel.type in self.kernels: + if self.kernel_type in self.kernels: kernel = self.kernels[self.kernel_type] values = kernel(values, self.kernel_factor) cat_true = kernel(cat_true, self.kernel_factor)