Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/cotorra/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def extract(
print(f"\n[green]✓[/green] Extraction completed in {t1 - t0:.2f}s.")
for split in extractor.loader.splits:
a = "-all" if all_times else ""
output = extractor.processed_data_home / f"features{a}-{split}.parquet"
output = extractor.processed_data_home / f"features{a}-{split}-{extractor.cfg.run_name}.parquet"
print(f" Output: {output}")


Expand Down Expand Up @@ -245,6 +245,10 @@ def rep_based_score(
"--verbose", "-v", help="Verbose logging for collate", is_flag=True
),
] = False,
ridge: Annotated[
bool,
typer.Option("--ridge", help="Use ridge regression (alpha=0.1) instead of LightGBM", is_flag=True),
] = False,
):
"""
Generate rep-based scores for the token-based outcomes of interest.
Expand All @@ -253,14 +257,15 @@ def rep_based_score(
with console.status("[bold green]Rep-based scoring on held-out data..."):
t0 = time.perf_counter()
scorer = RepBasedScorer(
main_cfg=main_config, processed_data_home=processed_data_home
main_cfg=main_config, processed_data_home=processed_data_home, ridge=ridge
)
scorer.save_all(verbose=verbose)
t1 = time.perf_counter()
print(f"\n[green]✓[/green] Rep-based scoring completed in {t1 - t0:.2f}s.")
scorer_tag = "ridge" if ridge else "rep-based"
out_path = (
scorer.processed_data_home
/ f"scores-rep-based-{scorer.cfg.run_name}.parquet"
/ f"scores-{scorer_tag}-{scorer.cfg.run_name}.parquet"
)
print(f" Scores: [cyan]{out_path}[/cyan]")

Expand Down
3 changes: 2 additions & 1 deletion src/cotorra/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,9 @@ def extract(self, all_times: bool = False):
lambda batch: self.extract_final(batch, all_times=all_times),
batched=True,
batch_size=self.cfg.get("extract", {}).get("batch_size", 8),
load_from_cache_file=False,
).to_parquet(
self.processed_data_home / f"features{a}-{split}{index}.parquet"
self.processed_data_home / f"features{a}-{split}-{self.cfg.run_name}{index}.parquet"
)


Expand Down
68 changes: 56 additions & 12 deletions src/cotorra/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@ def __init__(self, cfg=None, tkzr_cfg=None):
sorted(self.tkzr_cfg.lookup, key=self.tkzr_cfg.lookup.get)
)
self.logger = Logger()
self.pool_cats = False
self.kernels = {"cubic": lambda x, a: 0.5 + a*4*(x-0.5)**3 + (1-a) * (x - 0.5),
"atanh": lambda x, a: 0.5 + 1/(2*a)*t.atanh(a*(2*x - 1)),
"piece_const": lambda x, a: t.where((x > 0) & (x < 1), t.zeros_like(x), 2*(x - 0.5)),
"linear": lambda x, a: x
}

self.loss_functions = {
"mse": t.nn.MSELoss,
"mae": t.nn.L1Loss,
"smooth_mae": t.nn.SmoothL1Loss
"smooth_mae": t.nn.SmoothL1Loss,
}

if "label_weighted_loss" in self.cfg:
Expand All @@ -46,6 +48,9 @@ def __init__(self, cfg=None, tkzr_cfg=None):
"Quantile token loss is still experimental for fused tokenizers."
)

if "pool_categories" in self.cfg.quantile_token_loss:
self.pool_cats = True

self.q_type = np.array(
[
v.endswith(tuple(f"Q{i}" for i in range(self.tkzr_cfg.cfg.n_bins)))
Expand Down Expand Up @@ -79,15 +84,17 @@ def __init__(self, cfg=None, tkzr_cfg=None):
self.kernel_type = self.cfg.quantile_token_loss.kernel.type
else:
self.kernel_type = "linear"
self.kernel_factor = 1.0
if "factor" in self.cfg.quantile_token_loss.kernel:
self.kernel_factor = self.cfg.quantile_token_loss.kernel.factor


def quantile_token_loss(self, outputs, labels, **kwargs):
total_loss = t.zeros((), device=labels.device)
total_tokens = 0
shift_logits = outputs.get("logits")[:, :-1].contiguous()
shift_labels = labels[:, 1:].contiguous()
total_tokens = shift_labels.numel()
total_num_tokens = 0
def aggregate_first(cat_logits, cat_labels, i):
cat_preds = t.softmax(cat_logits, dim=-1) @ (
self.label_to_q[self.label_to_cat == i]
Expand All @@ -99,16 +106,16 @@ def aggregate_first(cat_logits, cat_labels, i):
cat_true = kernel(cat_true, self.kernel_factor)
return self.loss_function()(cat_preds, cat_true)
def loss_first(cat_logits, cat_labels, i):
    """
    Expected per-bin loss for category ``i`` under the predicted distribution.

    The loss between every candidate bin value and the true value is computed
    first (elementwise), then averaged with the softmax probabilities as
    weights, and finally averaged over tokens.

    cat_logits: logits restricted to the bins of category ``i``; the last
        dimension must line up with ``values`` below.
    cat_labels: label ids used to index ``self.label_to_q``.
    i: category index selecting rows of ``self.label_to_q``.
    """
    # True quantile value for each token.
    cat_true = self.label_to_q.to(device=cat_labels.device)[cat_labels]  # (N,)
    # Quantile value represented by each bin of this category.
    values = self.label_to_q[self.label_to_cat == i].to(device=cat_logits.device)  # (K,)
    if self.cfg.quantile_token_loss.kernel.type in self.kernels:
        kernel = self.kernels[self.kernel_type]
        values = kernel(values, self.kernel_factor)
        cat_true = kernel(cat_true, self.kernel_factor)
    n_tokens, n_bins = cat_true.shape[0], values.shape[0]
    # Expand both operands to a common (N, K) shape explicitly so the
    # elementwise loss never relies on implicit broadcasting.
    cat_true_full = cat_true.unsqueeze(-1).expand(n_tokens, n_bins)  # (N, K)
    values_full = values.unsqueeze(0).expand(n_tokens, n_bins)  # (N, K)
    losses = self.loss_function(reduction="none")(values_full, cat_true_full)
    # Probability-weighted loss per token, then mean over tokens.
    return (t.softmax(cat_logits, dim=-1) * losses).sum(dim=-1).mean()

for i in range(self.n_cats):
Expand All @@ -123,8 +130,44 @@ def loss_first(cat_logits, cat_labels, i):
else:
loss_del = aggregate_first(cat_logits, cat_labels, i)
total_loss += loss_del * n_i
total_num_tokens += n_i
return total_loss / max(total_tokens,1), total_num_tokens/total_tokens

def x_ent_loss_cat(self, outputs, labels, **kwargs):
    """
    Next-token cross-entropy with tokens pooled by category.

    For a target token whose category (looked up in ``self.label_to_cat``)
    is one of ``range(self.n_cats)``, the loss is the negative log of the
    total probability the model assigns to *any* token of that category.
    For every other target token the ordinary negative log-likelihood of
    the exact label is used.

    outputs: mapping with a ``"logits"`` entry whose first two dimensions
        equal ``labels.shape``.
    labels: integer label ids, used to index ``self.label_to_cat``.

    Returns the token-weighted mean of both contributions (0 if there are
    no target tokens).
    """
    logits = outputs.get("logits")
    assert logits.shape[:2] == labels.shape, \
        f"logits {logits.shape} and labels {labels.shape} disagree on (batch, seq)"
    # Position j predicts token j + 1.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    total_loss = t.zeros((), device=labels.device)
    total_tokens = 0
    label_to_cat = self.label_to_cat.to(device=labels.device)
    token_cats = label_to_cat[shift_labels]

    flat_logits = shift_logits.view(-1, shift_logits.size(-1))
    flat_labels = shift_labels.view(-1)
    flat_token_cats = token_cats.view(-1)
    log_probs = t.log_softmax(flat_logits, dim=-1)

    for i in range(self.n_cats):
        mask = flat_token_cats == i
        n_i = mask.sum().item()
        if n_i == 0:
            continue
        # Vocabulary columns that belong to category i.
        cat_cols = label_to_cat == i
        # log of the summed probability over all tokens in the category.
        log_p_cat = t.logsumexp(log_probs[mask][:, cat_cols], dim=-1)
        total_loss += -log_p_cat.mean() * n_i
        total_tokens += n_i

    # Do NOT return here: tokens outside the pooled categories still have to
    # contribute their ordinary cross-entropy below.
    non_numeric_mask = ~t.isin(
        flat_token_cats, t.arange(self.n_cats, device=labels.device)
    )
    n_non = non_numeric_mask.sum().item()
    if n_non > 0:
        # NOTE(review): nll_loss skips labels equal to its default
        # ignore_index while n_non counts them -- confirm whether padded
        # labels can reach this point.
        loss_non = t.nn.functional.nll_loss(
            log_probs[non_numeric_mask], flat_labels[non_numeric_mask]
        )
        total_loss += loss_non * n_non
        total_tokens += n_non

    return total_loss / max(total_tokens, 1)

def label_weighted_loss(self, outputs, labels, **kwargs):
logits = outputs.get("logits") # (batch, seq_len, vocab_size)
Expand All @@ -150,13 +193,14 @@ def custom_loss(self, outputs, labels, **kwargs):
log |= {"label_weighted_loss": label_weighted_loss.item()}
loss += label_weighted_loss
else:
x_ent_loss = self.x_ent_loss(outputs, labels)
x_ent_loss = self.x_ent_loss(outputs, labels) if (not self.pool_cats) else self.x_ent_loss_cat(outputs, labels)
log |= {"x_ent_loss": x_ent_loss.item()}
loss += x_ent_loss
if "quantile_token_loss" in self.cfg:
quantile_token_loss = self.quantile_token_loss(outputs, labels)
log |= {"quantile_token_loss": quantile_token_loss.item()}
quantile_token_loss, frac_numeric = self.quantile_token_loss(outputs, labels)
log |= {"quantile_token_loss": quantile_token_loss.item()/frac_numeric, "frac_numeric": frac_numeric}
loss += self.cfg.quantile_token_loss.qt_weight * quantile_token_loss
loss = loss/2
if wandb.run is not None:
log |= {"custom_loss": loss.item()}
wandb.log(log)
Expand Down
55 changes: 37 additions & 18 deletions src/cotorra/scorer_rep_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@
make representation-based predictions on held-out data
"""

import glob as glob_module
import pathlib

import lightgbm as lgb
import numpy as np
import polars as pl
import tqdm
from omegaconf import OmegaConf
from sklearn.linear_model import LogisticRegression

from cotorra.logger import Logger


class RepBasedScorer:
def __init__(self, main_cfg: pathlib.Path | str = None, **kwargs):
def __init__(self, main_cfg: pathlib.Path | str = None, ridge: bool = False, **kwargs):
parsed = OmegaConf.load(
pathlib.Path(main_cfg if main_cfg is not None else "./config/main.yaml")
.expanduser()
Expand All @@ -30,12 +32,17 @@ def __init__(self, main_cfg: pathlib.Path | str = None, **kwargs):
)
self.tkzr_cfg = OmegaConf.load(self.processed_data_home / "tokenizer.yaml")
self.logger = Logger()
self.ridge = ridge

self.splits = ("train", "tuning", "held_out")

run_name = self.cfg.wandb.run_name
self.features = {
s: np.vstack(
pl.scan_parquet(self.processed_data_home / f"features-{s}.parquet")
pl.scan_parquet(sorted(
glob_module.glob(str(self.processed_data_home / f"features-{s}-{run_name}.parquet"))
+ glob_module.glob(str(self.processed_data_home / f"features-{s}-{run_name}-[0-9]*.parquet"))
))
.select("features")
.collect()
.to_series()
Expand All @@ -61,36 +68,48 @@ def score_label(self, target_token="DSCG//expired"):
self.labels["held_out"].select(cols[0]).collect().to_numpy().ravel()
)

bst = lgb.LGBMClassifier(min_data_in_leaf=5, num_leaves=64)
bst.fit(
X=self.features["train"][train_valid],
y=train_label[train_valid],
eval_set=[
(self.features["tuning"][tuning_valid], tuning_label[tuning_valid])
],
eval_metric="auc",
)

scores = np.nan * np.ones_like(held_out_valid)
scores[held_out_valid] = bst.predict_proba(
X=self.features["held_out"][held_out_valid]
)[:, 1]
if self.ridge:
bst = LogisticRegression(penalty="l2", C=10.0, max_iter=1000)
bst.fit(
X=self.features["train"][train_valid],
y=train_label[train_valid],
)
scores = np.nan * np.ones_like(held_out_valid, dtype=float)
scores[held_out_valid] = bst.predict_proba(
X=self.features["held_out"][held_out_valid]
)[:, 1]
else:
bst = lgb.LGBMClassifier(min_data_in_leaf=5, num_leaves=64)
bst.fit(
X=self.features["train"][train_valid],
y=train_label[train_valid],
eval_set=[
(self.features["tuning"][tuning_valid], tuning_label[tuning_valid])
],
eval_metric="auc",
)
scores = np.nan * np.ones_like(held_out_valid)
scores[held_out_valid] = bst.predict_proba(
X=self.features["held_out"][held_out_valid]
)[:, 1]

return scores

def score(self):
    """
    Score every configured target token once.

    Returns a dict mapping ``"<target_token>_<suffix>"`` to the array
    returned by ``self.score_label`` for that token, where the suffix is
    ``"ridge_score"`` when ``self.ridge`` is set and ``"rep_score"``
    otherwise.
    """
    suffix = "ridge_score" if self.ridge else "rep_score"
    # One score_label call per token: each call fits a model, so a second
    # call under a differently named key would double the work.
    return {
        f"{tt}_{suffix}": self.score_label(target_token=tt)
        for tt in tqdm.tqdm(self.cfg.score.target_tokens, position=0)
    }

def save_all(self, verbose: bool = False):
scorer_tag = "ridge" if self.ridge else "rep-based"
(
df_res := self.labels["held_out"].with_columns(pl.from_dict(self.score()))
).sink_parquet(
self.processed_data_home
/ f"scores-rep-based-{self.cfg.wandb.run_name}.parquet"
/ f"scores-{scorer_tag}-{self.cfg.wandb.run_name}.parquet"
)

if verbose:
Expand Down