Skip to content

Commit dfc8d09

Browse files
committed
feat: add per-student SHAP explainability and fine-tuning feasibility report
Add SHAP TreeExplainer integration to the ML pipeline (Step 10b) that computes per-student feature attributions for retention, gateway math, gateway English, and low GPA models. SHAP values are stored as a JSON column alongside predictions, surfaced through the student detail API, and consumed by the LLM enrichment path for grounded explanations. - compute_shap_explanations() handles both XGBoost and RandomForest - enrich_with_llm() now accepts SHAP data for attribution-aware prompts - Student API returns parsed shap_explanations to frontend - Add shap>=0.44.0 dependency - Add model-client.ts Ollama/OpenAI dual-backend adapter - Add fine-tuning feasibility report with explainability analysis
1 parent ae971e4 commit dfc8d09

6 files changed

Lines changed: 581 additions & 5 deletions

File tree

ai_model/complete_ml_pipeline.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1046,6 +1046,119 @@ def assign_alert_level(risk_score):
10461046

10471047
print(f"Low GPA predictions generated")
10481048

1049+
# ============================================================================
# STEP 10b: PER-STUDENT SHAP EXPLANATIONS
# ============================================================================
print("\n" + "=" * 80)
print("STEP 10b: COMPUTING PER-STUDENT SHAP EXPLANATIONS")
print("=" * 80)

# NOTE(review): imports placed mid-script rather than at the top of the file —
# presumably to keep the optional shap dependency local to this step so the
# earlier pipeline steps can run without it; confirm before moving.
# json is aliased to _json to avoid shadowing any existing local name.
import shap
import json as _json
1059+
def compute_shap_explanations(model, X_data, feature_names, model_label, top_n=5):
    """
    Compute per-student SHAP feature attributions with TreeExplainer.

    For binary classifiers the class-1 (positive outcome) SHAP values are
    used. Each returned entry holds the top ``top_n`` positive and negative
    contributors plus the full per-feature SHAP vector for downstream use
    (e.g. by the fine-tuned explainer).
    """
    explainer = shap.TreeExplainer(model)
    raw_values = explainer.shap_values(X_data)

    # Normalise the shapes SHAP returns for binary classifiers:
    # RandomForest -> list [class_0, class_1]; some versions -> 3D array
    # (samples, features, classes); XGBoost default -> plain 2D array.
    if isinstance(raw_values, list):
        class1_shap = raw_values[1]
    elif raw_values.ndim == 3:
        class1_shap = raw_values[:, :, 1]
    else:
        class1_shap = raw_values

    # Expected model output before any feature contribution is applied.
    expected = explainer.expected_value
    if isinstance(expected, (list, np.ndarray)):
        base_value = float(expected[1]) if len(expected) > 1 else float(expected[0])
    else:
        base_value = float(expected)

    results = []
    for idx in range(len(X_data)):
        shap_row = class1_shap[idx]
        values_row = X_data.iloc[idx] if hasattr(X_data, 'iloc') else X_data[idx]

        # One dict per feature: name, rounded SHAP value, raw feature value.
        contribs = []
        for pos, fname in enumerate(feature_names):
            fval = values_row.iloc[pos] if hasattr(values_row, 'iloc') else values_row[pos]
            contribs.append({
                "feature": fname,
                "shap_value": round(float(shap_row[pos]), 4),
                "value": float(fval) if isinstance(fval, (int, float, np.integer, np.floating)) else str(fval),
            })

        # Largest upward pushes first; most negative (strongest downward) first.
        positives = sorted(
            (c for c in contribs if c["shap_value"] > 0),
            key=lambda c: c["shap_value"],
            reverse=True,
        )
        negatives = sorted(
            (c for c in contribs if c["shap_value"] < 0),
            key=lambda c: c["shap_value"],
        )

        results.append({
            "base_value": round(base_value, 4),
            "top_positive": positives[:top_n],
            "top_negative": negatives[:top_n],
            "all_contributions": contribs,
        })

    return results
1119+
1120+
1121+
# Models to explain with SHAP (all 4 XGBoost/RF classifiers)
shap_targets = {
    "retention": (retention_model, X_full_retention, retention_features),
    "gateway_math": (gateway_math_model, X_full_gateway_math, gateway_math_features),
    "gateway_english": (gateway_english_model, X_full_gateway_english, gateway_english_features),
    "low_gpa": (low_gpa_model, X_gpa_clean, gpa_features),
}

# Run the explainer per model and log a small sample for eyeballing.
shap_results = {}
for model_label, (mdl, feature_matrix, feature_cols) in shap_targets.items():
    print(f"\nComputing SHAP explanations for {model_label} model...")
    per_student = compute_shap_explanations(mdl, feature_matrix, feature_cols, model_label)
    shap_results[model_label] = per_student
    print(f"  ✓ {len(per_student)} student explanations generated")
    if per_student:
        sample = per_student[0]
        print(f"  Sample (student 0): base_value={sample['base_value']}")
        for contrib in sample['top_positive'][:3]:
            print(f"    ↑ {contrib['feature']}: +{contrib['shap_value']}")
        for contrib in sample['top_negative'][:3]:
            print(f"    ↓ {contrib['feature']}: {contrib['shap_value']}")

# Attach SHAP explanations as JSON column on the main dataframe.
# Stores only top contributors per model to keep DB size manageable.
print("\nAttaching SHAP explanations to student dataframe...")

def _shap_payload(row_idx):
    # Per-model top contributors for one student row; models whose
    # explanation set is shorter than the dataframe are simply omitted.
    return {
        model_label: {
            "base_value": per_student[row_idx]["base_value"],
            "top_positive": per_student[row_idx]["top_positive"],
            "top_negative": per_student[row_idx]["top_negative"],
        }
        for model_label, per_student in shap_results.items()
        if row_idx < len(per_student)
    }

shap_json_col = [_json.dumps(_shap_payload(i)) for i in range(len(df))]

df['shap_explanations'] = shap_json_col
print(f"✓ SHAP explanations attached as JSON column ({len(shap_json_col):,} students)")
1161+
10491162
# ============================================================================
10501163
# STEP 11: SAVE PREDICTIONS TO STUDENT-LEVEL FILE
10511164
# ============================================================================
@@ -1063,7 +1176,8 @@ def assign_alert_level(risk_score):
10631176
'prob_no_credential', 'prob_certificate', 'prob_associate', 'prob_bachelor',
10641177
'gateway_math_probability', 'gateway_math_prediction', 'gateway_math_risk',
10651178
'gateway_english_probability', 'gateway_english_prediction', 'gateway_english_risk',
1066-
'low_gpa_probability', 'low_gpa_prediction', 'academic_risk_level'
1179+
'low_gpa_probability', 'low_gpa_prediction', 'academic_risk_level',
1180+
'shap_explanations'
10671181
]
10681182

10691183
predictions_df = df[prediction_columns].copy()

ai_model/generate_readiness_scores.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -447,13 +447,17 @@ def score_student(row) -> dict:
447447
# LLM Enrichment (optional)
448448
# ============================================================================
449449

450-
def enrich_with_llm(record: dict, model: str) -> dict:
450+
def enrich_with_llm(record: dict, model: str, shap_data: dict = None) -> dict:
451451
"""
452452
Replace rationale and suggested_actions with LLM-generated content.
453453
Only called for medium/low readiness students.
454454
Input is the FERPA-safe profile — no PII sent to any external service.
455455
Returns the record with enriched text fields (score unchanged).
456456
457+
When shap_data is provided (from the ML pipeline's SHAP step), the prompt
458+
includes per-model feature attribution so the LLM can ground its
459+
explanation in what the models actually learned.
460+
457461
Provider is determined by the model string:
458462
"gpt-4o-mini" -> OpenAI (requires OPENAI_API_KEY)
459463
"ollama/llama3.2:3b" -> local Ollama (no key needed)
@@ -469,6 +473,25 @@ def enrich_with_llm(record: dict, model: str) -> dict:
469473
profile = json.loads(record["input_features"]) if isinstance(record["input_features"], str) else record["input_features"]
470474
risk_factors = json.loads(record["risk_factors"]) if isinstance(record["risk_factors"], str) else []
471475

476+
# Build SHAP context section if available
477+
shap_section = ""
478+
if shap_data:
479+
shap_lines = []
480+
for model_name, attrs in shap_data.items():
481+
shap_lines.append(f"\n {model_name} model (base prediction: {attrs.get('base_value', 'N/A')}):")
482+
for f in attrs.get("top_positive", []):
483+
shap_lines.append(f" ↑ {f['feature']} = {f['value']} (pushes prediction UP by {f['shap_value']})")
484+
for f in attrs.get("top_negative", []):
485+
shap_lines.append(f" ↓ {f['feature']} = {f['value']} (pushes prediction DOWN by {abs(f['shap_value'])})")
486+
shap_section = f"""
487+
488+
ML Model Feature Attribution (SHAP — shows which features drive each prediction):
489+
{''.join(shap_lines)}
490+
491+
IMPORTANT: Use these SHAP values to ground your explanation. Tell the advisor
492+
which specific factors are most responsible for this student's risk level,
493+
citing the magnitude. Do not speculate beyond what the models show."""
494+
472495
prompt = f"""You are an academic advisor assistant at Bishop State Community College.
473496
A student has a readiness score of {record['readiness_score']:.2f} ({record['readiness_level']} readiness).
474497
@@ -484,10 +507,10 @@ def enrich_with_llm(record: dict, model: str) -> dict:
484507
- Retention probability: {profile.get('retention_probability')}
485508
486509
Identified risk factors:
487-
{chr(10).join(f'- {f}' for f in risk_factors)}
510+
{chr(10).join(f'- {f}' for f in risk_factors)}{shap_section}
488511
489512
Write two things:
490-
1. RATIONALE: A 2-sentence explanation of this student's readiness score for an advisor.
513+
1. RATIONALE: A 2-3 sentence explanation of this student's readiness score for an advisor. If SHAP data is available, cite the top contributing factors by name and magnitude.
491514
2. ACTIONS: A JSON array of 3-5 specific, actionable intervention recommendations (strings only).
492515
493516
Format your response exactly as:
@@ -588,7 +611,15 @@ def main():
588611
record["generation_ms"] = elapsed_ms
589612
record["run_id"] = run_id
590613
if args.enrich_with_llm and record["readiness_level"] in ("medium", "low"):
591-
record = enrich_with_llm(record, args.llm_model)
614+
# Pass SHAP data if the shap_explanations column exists
615+
shap_data = None
616+
shap_raw = row.get("shap_explanations")
617+
if shap_raw and str(shap_raw) not in ("", "nan", "None"):
618+
try:
619+
shap_data = json.loads(shap_raw) if isinstance(shap_raw, str) else shap_raw
620+
except (json.JSONDecodeError, TypeError):
621+
pass
622+
record = enrich_with_llm(record, args.llm_model, shap_data=shap_data)
592623
records.append(record)
593624
except Exception as e:
594625
errors += 1

codebenders-dashboard/app/api/students/[guid]/route.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ export async function GET(
2929
ROUND((s.low_gpa_probability * 100)::numeric, 1) AS gpa_risk_pct,
3030
ROUND(s.predicted_time_to_credential::numeric, 1) AS time_to_credential,
3131
s.predicted_credential_label AS credential_type,
32+
s.shap_explanations,
3233
ROUND((r.readiness_score * 100)::numeric, 1) AS readiness_pct,
3334
r.readiness_level,
3435
r.rationale,
@@ -51,10 +52,21 @@ export async function GET(
5152
}
5253

5354
const row = result.rows[0]
55+
// Parse JSON string columns into objects for the frontend
56+
let shap = null
57+
if (row.shap_explanations) {
58+
try {
59+
shap = typeof row.shap_explanations === "string"
60+
? JSON.parse(row.shap_explanations)
61+
: row.shap_explanations
62+
} catch { shap = null }
63+
}
64+
5465
return NextResponse.json({
5566
...row,
5667
risk_factors: row.risk_factors ? JSON.parse(row.risk_factors) : [],
5768
suggested_actions: row.suggested_actions ? JSON.parse(row.suggested_actions) : [],
69+
shap_explanations: shap,
5870
})
5971
} catch (error) {
6072
console.error("Student detail fetch error:", error)
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/**
2+
* Model client adapter — routes inference to Ollama (fine-tuned) or
3+
* OpenAI (fallback) based on MODEL_BACKEND env var.
4+
*/
5+
6+
import { generateText } from "ai"
7+
import { createOpenAI } from "@ai-sdk/openai"
8+
9+
// Backend selector: "ollama" routes inference to the local server below;
// any other value falls through to OpenAI (see generate()).
const MODEL_BACKEND = process.env.MODEL_BACKEND || "openai"
// School identifier used to namespace the fine-tuned Ollama model names.
const SCHOOL_CODE = process.env.SCHOOL_CODE || "bishop-state"
// Base URL of the local Ollama HTTP API.
const OLLAMA_BASE_URL = process.env.OLLAMA_BASE_URL || "http://localhost:11434"
// Model size tag appended to the Ollama model name (e.g. "bishop-state-explainer:9b").
const MODEL_SIZE = process.env.MODEL_SIZE || "9b"

// Lazily-created, memoized OpenAI client instance.
let _openai: ReturnType<typeof createOpenAI> | null = null
16+
function getOpenAI() {
17+
if (!_openai) {
18+
_openai = createOpenAI({ apiKey: process.env.OPENAI_API_KEY || "" })
19+
}
20+
return _openai
21+
}
22+
23+
async function callOllama(model: string, prompt: string, maxTokens: number): Promise<string> {
24+
const response = await fetch(`${OLLAMA_BASE_URL}/api/generate`, {
25+
method: "POST",
26+
headers: { "Content-Type": "application/json" },
27+
body: JSON.stringify({
28+
model,
29+
prompt,
30+
stream: false,
31+
options: {
32+
temperature: 0.3,
33+
num_predict: maxTokens,
34+
},
35+
}),
36+
})
37+
38+
if (!response.ok) {
39+
throw new Error(`Ollama error: ${response.status} ${response.statusText}`)
40+
}
41+
42+
const data = await response.json()
43+
return data.response
44+
}
45+
46+
async function generate(
47+
task: "explainer" | "summarizer",
48+
prompt: string,
49+
maxTokens: number,
50+
): Promise<string> {
51+
if (MODEL_BACKEND === "ollama") {
52+
const model = `${SCHOOL_CODE}-${task}:${MODEL_SIZE}`
53+
return callOllama(model, prompt, maxTokens)
54+
}
55+
const result = await generateText({
56+
model: getOpenAI()("gpt-4o-mini"),
57+
prompt,
58+
maxOutputTokens: maxTokens,
59+
})
60+
return result.text
61+
}
62+
63+
/**
64+
* Generate a course pairing explanation.
65+
*/
66+
export async function generateExplanation(
67+
prompt: string,
68+
maxTokens: number = 320,
69+
): Promise<string> {
70+
return generate("explainer", prompt, maxTokens)
71+
}
72+
73+
/**
74+
* Generate a query result summary.
75+
*/
76+
export async function generateSummary(
77+
prompt: string,
78+
maxTokens: number = 200,
79+
): Promise<string> {
80+
return generate("summarizer", prompt, maxTokens)
81+
}

0 commit comments

Comments
 (0)