Skip to content

Commit e2911ca

Browse files
committed
refactor: simplify narrator integration per code review
- Remove dead generate_explainer_pairs/generate_summarizer_pairs wrappers
- Derive _REQUIRED_KEYS sets from schema dicts in prompts.py
- Simplify dedup loop in check_shap_grounding with dict.fromkeys()
- Add narrator key to load_seed_queries() return
- Update test_distill.py to use generate_pairs() directly
- Fix stale docstring in generate_pairs()
1 parent ac42cfb commit e2911ca

4 files changed

Lines changed: 15 additions & 56 deletions

File tree

tests/training/test_distill.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
from training.distill import (
88
validate_json,
99
call_teacher,
10-
generate_explainer_pairs,
11-
generate_summarizer_pairs,
10+
generate_pairs,
1211
)
1312

1413

@@ -80,10 +79,11 @@ def test_generates_pairs_from_seed_data(self, sample_school_config, sample_cours
8079
})
8180

8281
with patch("training.distill.call_teacher", return_value=mock_response):
83-
pairs = generate_explainer_pairs(
82+
pairs = generate_pairs(
8483
config=sample_school_config,
8584
seed_data=[sample_course_pairing_data],
8685
count=2,
86+
task="explainer",
8787
)
8888

8989
assert len(pairs) == 2
@@ -92,10 +92,11 @@ def test_generates_pairs_from_seed_data(self, sample_school_config, sample_cours
9292

9393
def test_skips_invalid_responses(self, sample_school_config, sample_course_pairing_data):
9494
with patch("training.distill.call_teacher", return_value="not json"):
95-
pairs = generate_explainer_pairs(
95+
pairs = generate_pairs(
9696
config=sample_school_config,
9797
seed_data=[sample_course_pairing_data],
9898
count=3,
99+
task="explainer",
99100
)
100101

101102
assert len(pairs) == 0
@@ -112,10 +113,11 @@ def test_generates_pairs_from_seed_data(self, sample_school_config, sample_query
112113
})
113114

114115
with patch("training.distill.call_teacher", return_value=mock_response):
115-
pairs = generate_summarizer_pairs(
116+
pairs = generate_pairs(
116117
config=sample_school_config,
117118
seed_data=[sample_query_result_data],
118119
count=2,
120+
task="summarizer",
119121
)
120122

121123
assert len(pairs) == 2

training/distill.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def generate_pairs(
171171
config: Parsed school config.
172172
seed_data: List of seed data dicts.
173173
count: Number of pairs to generate.
174-
task: "explainer" or "summarizer".
174+
task: "narrator", "explainer", or "summarizer".
175175
outfile: If provided, pairs are written incrementally.
176176
system_prompt: Pre-built system prompt (avoids recomputation).
177177
"""
@@ -221,24 +221,6 @@ def generate_pairs(
221221
return pairs
222222

223223

224-
def generate_explainer_pairs(
225-
config: dict[str, Any], seed_data: list[dict[str, Any]],
226-
count: int, outfile: Path | None = None,
227-
system_prompt: str | None = None,
228-
) -> list[dict]:
229-
"""Generate explainer training pairs via teacher model distillation."""
230-
return generate_pairs(config, seed_data, count, "explainer", outfile, system_prompt)
231-
232-
233-
def generate_summarizer_pairs(
234-
config: dict[str, Any], seed_data: list[dict[str, Any]],
235-
count: int, outfile: Path | None = None,
236-
system_prompt: str | None = None,
237-
) -> list[dict]:
238-
"""Generate summarizer training pairs via teacher model distillation."""
239-
return generate_pairs(config, seed_data, count, "summarizer", outfile, system_prompt)
240-
241-
242224
def main(school: str, local: bool = False) -> None:
243225
"""Run distillation for a school."""
244226
config = load_school_config(school)

training/eval.py

Lines changed: 6 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,34 +17,15 @@
1717
from typing import Any
1818

1919
from training.config import get_message_content, get_training_data_dir, read_jsonl
20+
from training.prompts import EXPLAINER_SCHEMA, NARRATOR_SCHEMA, SUMMARIZER_SCHEMA
2021

2122
# ---------------------------------------------------------------------------
22-
# Required keys per task
23+
# Required keys per task — derived from schema definitions in prompts.py
2324
# ---------------------------------------------------------------------------
2425

25-
_EXPLAINER_REQUIRED_KEYS: set[str] = {
26-
"explanation",
27-
"structural_factors",
28-
"student_impact",
29-
"advisor_recommendation",
30-
"data_limitations",
31-
"related_intervention",
32-
}
33-
34-
_NARRATOR_REQUIRED_KEYS: set[str] = {
35-
"narrative",
36-
"key_drivers",
37-
"recommended_actions",
38-
"data_limitations",
39-
}
40-
41-
_SUMMARIZER_REQUIRED_KEYS: set[str] = {
42-
"summary",
43-
"key_insights",
44-
"context",
45-
"action_items",
46-
"caveats",
47-
}
26+
_EXPLAINER_REQUIRED_KEYS: set[str] = set(EXPLAINER_SCHEMA.keys())
27+
_NARRATOR_REQUIRED_KEYS: set[str] = set(NARRATOR_SCHEMA.keys())
28+
_SUMMARIZER_REQUIRED_KEYS: set[str] = set(SUMMARIZER_SCHEMA.keys())
4829

4930
# ---------------------------------------------------------------------------
5031
# Ship criteria — minimum thresholds per task
@@ -204,14 +185,7 @@ def check_shap_grounding(outputs: list[str], inputs: list[dict[str, Any]], min_f
204185
top_features.append(entry["feature"])
205186
for entry in model_attrs.get("top_negative", [])[:3]:
206187
top_features.append(entry["feature"])
207-
# Deduplicate while preserving order
208-
seen = set()
209-
unique_features = []
210-
for f in top_features:
211-
if f not in seen:
212-
seen.add(f)
213-
unique_features.append(f)
214-
top_features = unique_features[:6] # top 3 per direction, deduplicated
188+
top_features = list(dict.fromkeys(top_features))[:6]
215189

216190
if not top_features:
217191
passing += 1 # no SHAP data to ground against

training/seed.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def load_seed_queries(school: str) -> dict[str, list[dict]]:
168168
with seed_path.open("r", encoding="utf-8") as fh:
169169
data = yaml.safe_load(fh) or {}
170170
return {
171+
"narrator": data.get("narrator", []),
171172
"explainer": data.get("explainer", []),
172173
"summarizer": data.get("summarizer", []),
173174
}

0 commit comments

Comments (0)