Skip to content

Commit ac42cfb

Browse files
committed
feat(#97): add SHAP narrator task type to training pipeline
Add narrator as a third task alongside explainer and summarizer. The narrator takes per-student SHAP values + profile and generates advisor-facing narratives grounded in ML feature attribution. - prompts.py: NARRATOR_SCHEMA, NARRATOR_STUDENT_SYSTEM, build_narrator_prompt() - seed.py: generate_synthetic_student_profiles() with SHAP data - distill.py: narrator in _TASK_CONFIG, included in main() distillation loop - eval.py: _NARRATOR_REQUIRED_KEYS, shap_grounding ship criterion (>= 80%), check_shap_grounding() metric (counts feature name mentions in narrative) - prepare.py: narrator added to task iteration
1 parent 807bb78 commit ac42cfb

5 files changed

Lines changed: 282 additions & 17 deletions

File tree

training/distill.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,18 @@
2121
from training.config import get_training_data_dir, load_school_config, write_jsonl
2222
from training.prompts import (
2323
EXPLAINER_STUDENT_SYSTEM,
24+
NARRATOR_STUDENT_SYSTEM,
2425
SUMMARIZER_STUDENT_SYSTEM,
2526
build_explainer_prompt,
27+
build_narrator_prompt,
2628
build_summarizer_prompt,
2729
build_system_prompt,
2830
)
2931
from training.seed import (
3032
format_as_chatml,
3133
generate_synthetic_course_pairings,
3234
generate_synthetic_query_results,
35+
generate_synthetic_student_profiles,
3336
)
3437

3538
# Cost tracking
@@ -136,6 +139,11 @@ def call_teacher(system: str, user: str, backend: str, model: str) -> str:
136139
_FLUSH_INTERVAL = 25
137140

138141
_TASK_CONFIG = {
142+
"narrator": {
143+
"prompt_builder": build_narrator_prompt,
144+
"student_system": NARRATOR_STUDENT_SYSTEM,
145+
"format_user": lambda config, data: json.dumps(data, ensure_ascii=False, default=str),
146+
},
139147
"explainer": {
140148
"prompt_builder": build_explainer_prompt,
141149
"student_system": EXPLAINER_STUDENT_SYSTEM,
@@ -245,28 +253,43 @@ def main(school: str, local: bool = False) -> None:
245253
data_dir = get_training_data_dir(school)
246254
pairs_dir = data_dir / "pairs"
247255

248-
synthetic_pairings = generate_synthetic_course_pairings(config, count=pairs_per_task)
249-
synthetic_results = generate_synthetic_query_results(config, count=pairs_per_task)
250-
251256
system_prompt = build_system_prompt(config)
252257

258+
all_counts: dict[str, int] = {}
259+
260+
# Narrator
261+
print(f"\n{'='*60}\nNARRATOR — generating {pairs_per_task} pairs\n{'='*60}")
262+
synthetic_profiles = generate_synthetic_student_profiles(config, count=pairs_per_task)
263+
narrator_pairs = generate_pairs(
264+
config=config, seed_data=synthetic_profiles,
265+
count=pairs_per_task, task="narrator", outfile=pairs_dir / "narrator.jsonl",
266+
system_prompt=system_prompt,
267+
)
268+
all_counts["narrator"] = len(narrator_pairs)
269+
270+
# Explainer
253271
print(f"\n{'='*60}\nEXPLAINER — generating {pairs_per_task} pairs\n{'='*60}")
254-
explainer_pairs = generate_explainer_pairs(
272+
synthetic_pairings = generate_synthetic_course_pairings(config, count=pairs_per_task)
273+
explainer_pairs = generate_pairs(
255274
config=config, seed_data=synthetic_pairings,
256-
count=pairs_per_task, outfile=pairs_dir / "explainer.jsonl",
275+
count=pairs_per_task, task="explainer", outfile=pairs_dir / "explainer.jsonl",
257276
system_prompt=system_prompt,
258277
)
278+
all_counts["explainer"] = len(explainer_pairs)
259279

280+
# Summarizer
260281
print(f"\n{'='*60}\nSUMMARIZER — generating {pairs_per_task} pairs\n{'='*60}")
261-
summarizer_pairs = generate_summarizer_pairs(
282+
synthetic_results = generate_synthetic_query_results(config, count=pairs_per_task)
283+
summarizer_pairs = generate_pairs(
262284
config=config, seed_data=synthetic_results,
263-
count=pairs_per_task, outfile=pairs_dir / "summarizer.jsonl",
285+
count=pairs_per_task, task="summarizer", outfile=pairs_dir / "summarizer.jsonl",
264286
system_prompt=system_prompt,
265287
)
288+
all_counts["summarizer"] = len(summarizer_pairs)
266289

267290
print(f"\n{'='*60}\nDISTILLATION COMPLETE\n{'='*60}")
268-
print(f" Explainer: {len(explainer_pairs)} pairs")
269-
print(f" Summarizer: {len(summarizer_pairs)} pairs")
291+
for task_name, count in all_counts.items():
292+
print(f" {task_name.capitalize()}: {count} pairs")
270293
_print_cost_summary()
271294

272295

training/eval.py

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@
3131
"related_intervention",
3232
}
3333

34+
# Keys that every valid narrator JSON output must contain; mirrors
# NARRATOR_SCHEMA in training/prompts.py. Used by check_schema_adherence().
_NARRATOR_REQUIRED_KEYS: set[str] = {
    "narrative",
    "key_drivers",
    "recommended_actions",
    "data_limitations",
}
40+
3441
_SUMMARIZER_REQUIRED_KEYS: set[str] = {
3542
"summary",
3643
"key_insights",
@@ -44,6 +51,12 @@
4451
# ---------------------------------------------------------------------------
4552

4653
SHIP_CRITERIA: dict[str, dict[str, float]] = {
54+
"narrator": {
55+
"json_validity": 0.95,
56+
"schema_adherence": 0.90,
57+
"shap_grounding": 0.80,
58+
"caveat_inclusion": 0.85,
59+
},
4760
"explainer": {
4861
"json_validity": 0.95,
4962
"schema_adherence": 0.90,
@@ -120,9 +133,11 @@ def check_schema_adherence(outputs: list[str], task: str) -> float:
120133
"""Fraction of valid JSON outputs that contain all required keys."""
121134
if not outputs:
122135
return 0.0
123-
required = (
124-
_EXPLAINER_REQUIRED_KEYS if task == "explainer" else _SUMMARIZER_REQUIRED_KEYS
125-
)
136+
required = {
137+
"narrator": _NARRATOR_REQUIRED_KEYS,
138+
"explainer": _EXPLAINER_REQUIRED_KEYS,
139+
"summarizer": _SUMMARIZER_REQUIRED_KEYS,
140+
}.get(task, _SUMMARIZER_REQUIRED_KEYS)
126141
passing = 0
127142
total = 0
128143
for text in outputs:
@@ -147,7 +162,7 @@ def check_caveat_inclusion(outputs: list[str], task: str) -> float:
147162
"""
148163
if not outputs:
149164
return 0.0
150-
caveat_key = "data_limitations" if task == "explainer" else "caveats"
165+
caveat_key = "caveats" if task == "summarizer" else "data_limitations"
151166
passing = 0
152167
total = 0
153168
for text in outputs:
@@ -169,6 +184,51 @@ def check_caveat_inclusion(outputs: list[str], task: str) -> float:
169184
return passing / total if total else 0.0
170185

171186

187+
def check_shap_grounding(outputs: list[str], inputs: list[dict[str, Any]], min_features: int = 2) -> float:
    """Fraction of narrator outputs that mention at least `min_features` of the top-3 SHAP features.

    For each (output, input) pair, collects the top-3 positive and top-3
    negative SHAP feature names from every model listed in the input's
    ``shap`` mapping, deduplicates them (order-preserving, capped at 6),
    and checks whether the raw output text references at least
    ``min_features`` of them. Matching is case-insensitive and
    underscore-tolerant (``gpa_trend`` matches "GPA trend"). Inputs with
    no SHAP data count as passing, since there is nothing to ground
    against.

    Args:
        outputs: Raw model output strings (typically JSON text).
        inputs: Seed inputs aligned positionally with ``outputs``; each may
            carry a ``shap`` mapping of model name -> attribution dict with
            ``top_positive`` / ``top_negative`` entry lists.
        min_features: Minimum number of distinct top features the output
            must mention to pass.

    Returns:
        Passing fraction in [0.0, 1.0]; 0.0 when ``outputs`` is empty.
    """
    if not outputs:
        return 0.0

    def _norm(text: str) -> str:
        # Shared comparison form: case-insensitive, underscores as spaces.
        return text.lower().replace("_", " ")

    passing = 0
    total = 0
    # zip() truncates to the shorter list; unpaired trailing items are ignored.
    for output_text, input_data in zip(outputs, inputs):
        total += 1
        # Collect top SHAP feature names from all models in the input.
        shap_data = input_data.get("shap", {})
        top_features: list[str] = []
        for model_attrs in shap_data.values():
            for entry in model_attrs.get("top_positive", [])[:3]:
                feature = entry.get("feature")
                if feature:  # tolerate malformed entries rather than crash the eval
                    top_features.append(feature)
            for entry in model_attrs.get("top_negative", [])[:3]:
                feature = entry.get("feature")
                if feature:
                    top_features.append(feature)
        # Order-preserving dedup; cap at 6 (top 3 per direction).
        top_features = list(dict.fromkeys(top_features))[:6]

        if not top_features:
            passing += 1  # no SHAP data to ground against
            continue

        # Count how many top features appear anywhere in the output text.
        output_norm = _norm(output_text)
        mentioned = sum(1 for name in top_features if _norm(name) in output_norm)
        if mentioned >= min_features:
            passing += 1

    return passing / total if total else 0.0
230+
231+
172232
def check_factual_grounding(outputs: list[str], inputs: list[dict[str, Any]]) -> float:
173233
"""Fraction of outputs that contain numeric values referenced in their input.
174234
@@ -314,8 +374,11 @@ def run_eval(school: str, task: str) -> ShipDecision:
314374
"json_validity": check_json_validity(outputs),
315375
"schema_adherence": check_schema_adherence(outputs, task),
316376
"caveat_inclusion": check_caveat_inclusion(outputs, task),
317-
"factual_grounding": check_factual_grounding(outputs, inputs),
318377
}
378+
if task == "narrator":
379+
metrics["shap_grounding"] = check_shap_grounding(outputs, inputs)
380+
else:
381+
metrics["factual_grounding"] = check_factual_grounding(outputs, inputs)
319382

320383
print(f"\n[eval] Results for {school}/{task}:")
321384
for k, v in metrics.items():
@@ -337,13 +400,13 @@ def main() -> None:
337400
parser.add_argument("--school", required=True, help="School directory name (e.g. bishop-state)")
338401
parser.add_argument(
339402
"--task",
340-
choices=["explainer", "summarizer"],
403+
choices=["narrator", "explainer", "summarizer"],
341404
default=None,
342405
help="Task to evaluate (default: both)",
343406
)
344407
args = parser.parse_args()
345408

346-
tasks = [args.task] if args.task else ["explainer", "summarizer"]
409+
tasks = [args.task] if args.task else ["narrator", "explainer", "summarizer"]
347410
results: dict[str, ShipDecision] = {}
348411
for task in tasks:
349412
print(f"\n{'='*60}\nEVAL: {task.upper()}\n{'='*60}")

training/prepare.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def process_task(school: str, task: str) -> dict[str, int]:
133133

134134
def main(school: str) -> None:
135135
"""Run preparation for all tasks."""
136-
for task in ("explainer", "summarizer"):
136+
for task in ("narrator", "explainer", "summarizer"):
137137
try:
138138
process_task(school, task)
139139
except FileNotFoundError as e:

training/prompts.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,28 @@
2626
"caveats": ["data limitations relevant to this specific query"],
2727
}
2828

29+
# Output contract for the narrator task; rendered into the teacher prompt by
# build_narrator_prompt(). Values are human-readable descriptions of what each
# field should contain, not defaults.
NARRATOR_SCHEMA = {
    "narrative": "2-3 sentence explanation grounded in SHAP feature attribution",
    "key_drivers": ["ranked list of factors with direction and magnitude"],
    "recommended_actions": ["3-5 specific, actionable interventions"],
    "data_limitations": ["caveats about the prediction"],
}
35+
2936
EXPLAINER_STUDENT_SYSTEM = (
3037
"You are a student success analyst. Given course pairing data, generate a "
3138
"structured JSON explanation. Include: explanation, structural_factors, "
3239
"student_impact, advisor_recommendation, data_limitations, and "
3340
"related_intervention. Respond with ONLY valid JSON."
3441
)
3542

43+
# System prompt given to the distilled student model for the narrator task.
# Field names listed here must stay in sync with NARRATOR_SCHEMA.
NARRATOR_STUDENT_SYSTEM = (
    "You are a student success analyst. Given a student profile with ML prediction "
    "attribution (SHAP values), generate a structured JSON explanation. Include: "
    "narrative, key_drivers, recommended_actions, and data_limitations. "
    "Ground your narrative in the SHAP values — cite specific features by name "
    "and magnitude. Respond with ONLY valid JSON."
)
50+
3651
SUMMARIZER_STUDENT_SYSTEM = (
3752
"You are a student success analyst. Given a query and its results, generate "
3853
"a structured JSON summary. Include: summary, key_insights, context, "
@@ -195,6 +210,62 @@ def build_system_prompt(config: dict[str, Any]) -> str:
195210
return "\n\n".join(sections)
196211

197212

213+
def build_narrator_prompt(
    config: dict[str, Any],
    student_data: dict[str, Any],
) -> str:
    """Build the teacher prompt for generating a SHAP-grounded student narrative.

    Args:
        config: School configuration; active interventions are read from
            ``config["school"]["interventions"]["active"]``.
        student_data: Synthetic student record with ``student_profile``,
            ``shap`` attribution, ``risk_factors``, ``readiness_score``
            and ``readiness_level`` keys (all optional).

    Returns:
        The fully rendered teacher prompt string.
    """
    profile = student_data.get("student_profile", {})
    shap_data = student_data.get("shap", {})
    risk_factors = student_data.get("risk_factors", [])
    readiness_score = student_data.get("readiness_score", "N/A")
    readiness_level = student_data.get("readiness_level", "unknown")
    schema_str = json.dumps(NARRATOR_SCHEMA, indent=2)

    # Render one attribution section per model: a header line followed by
    # the positive (UP) and negative (DOWN) feature contributions.
    attr_lines: list[str] = []
    for model_name, attrs in shap_data.items():
        attr_lines.append(f"\n {model_name} model (base prediction: {attrs.get('base_value', 'N/A')}):")
        attr_lines.extend(
            f" + {item['feature']} = {item['value']} (pushes prediction UP by {item['shap_value']})"
            for item in attrs.get("top_positive", [])
        )
        attr_lines.extend(
            f" - {item['feature']} = {item['value']} (pushes prediction DOWN by {abs(item['shap_value'])})"
            for item in attrs.get("top_negative", [])
        )
    shap_section = ''.join(attr_lines) if attr_lines else 'No SHAP data available'

    profile_str = json.dumps(profile, indent=2, default=str)
    if risk_factors:
        risk_str = "\n".join(f"- {factor}" for factor in risk_factors)
    else:
        risk_str = "None identified"

    # Institution-specific interventions the narrative may reference by name.
    active = config.get("school", {}).get("interventions", {}).get("active", [])
    rendered = [
        f"- {entry['name']} ({entry['type']}): {entry.get('effectiveness', 'unknown')}"
        for entry in active
    ]
    interventions_str = "\n".join(rendered) if rendered else "None listed"

    return f"""A student at this institution has a readiness score of {readiness_score} ({readiness_level}).
Analyze their ML prediction factors and write an advisor-facing explanation.

STUDENT PROFILE:
{profile_str}

RISK FACTORS (rule-engine identified):
{risk_str}

ML MODEL FEATURE ATTRIBUTION (SHAP values — what drives each prediction):
{shap_section}

AVAILABLE INTERVENTIONS:
{interventions_str}

Generate a JSON response with this exact schema:
{schema_str}

Guidelines:
- Ground the narrative in SHAP values. Cite at least 2 of the top contributing features by name and magnitude.
- Explain in plain language what each factor means for this student's likelihood of success.
- Make recommended actions specific to this institution — reference active interventions by name when relevant.
- Include at least one data limitation or caveat about the prediction.
- Do NOT speculate beyond what the SHAP values and profile data show."""
267+
268+
198269
def build_explainer_prompt(
199270
config: dict[str, Any],
200271
course_data: dict[str, Any],

0 commit comments

Comments
 (0)