Skip to content

Commit 807bb78

Browse files
committed
Merge branch 'feature/shap-explainability' into fine-tuning/student-explainability
2 parents 29421fb + baff2b2 commit 807bb78

5 files changed

Lines changed: 482 additions & 7 deletions

File tree

ai_model/complete_ml_pipeline.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
)
2323
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
2424
import xgboost as xgb
25+
import shap
26+
import json
2527
from datetime import datetime
2628
import warnings
2729
warnings.filterwarnings('ignore')
@@ -1046,6 +1048,98 @@ def assign_alert_level(risk_score):
10461048

10471049
print(f"Low GPA predictions generated")
10481050

1051+
# ============================================================================
# STEP 10b: PER-STUDENT SHAP EXPLANATIONS
# ============================================================================
print("\n" + "=" * 80)
print("STEP 10b: COMPUTING PER-STUDENT SHAP EXPLANATIONS")
print("=" * 80)


def compute_shap_explanations(model, X_data, feature_names, top_n=5):
    """
    Compute per-student SHAP attributions with shap.TreeExplainer.

    For binary classifiers the class-1 (positive outcome) SHAP values are
    used. Each student gets the top ``top_n`` positive and negative feature
    contributors plus the explainer's base (expected) value.
    """
    tree_explainer = shap.TreeExplainer(model)
    raw_shap = tree_explainer.shap_values(X_data)

    # Normalise the three shapes TreeExplainer can return:
    # a list [class_0, class_1] (RandomForest), a 3-D array
    # (samples, features, classes), or a plain 2-D array (XGBoost).
    if isinstance(raw_shap, list):
        class1_shap = raw_shap[1]
    elif raw_shap.ndim == 3:
        class1_shap = raw_shap[:, :, 1]
    else:
        class1_shap = raw_shap

    # expected_value may be a scalar or a per-class sequence; pick class 1
    # when more than one entry is present.
    expected = tree_explainer.expected_value
    if isinstance(expected, (list, np.ndarray)):
        picked = expected[1] if len(expected) > 1 else expected[0]
        base_rounded = round(float(picked), 4)
    else:
        base_rounded = round(float(expected), 4)

    def _describe(shap_row, feature_row, idx):
        # One contributor record: feature name, rounded SHAP weight, and the
        # student's raw feature value (stringified when non-numeric).
        raw_val = feature_row.iloc[idx]
        is_numeric = isinstance(raw_val, (int, float, np.integer, np.floating))
        return {
            "feature": feature_names[idx],
            "shap_value": round(float(shap_row[idx]), 4),
            "value": float(raw_val) if is_numeric else str(raw_val),
        }

    explanations = []
    for row_idx in range(len(X_data)):
        shap_row = class1_shap[row_idx]
        feature_row = X_data.iloc[row_idx]

        # argsort picks the top contributors without materialising a full
        # sorted list of (feature, value) pairs.
        ups = np.argsort(-shap_row)[:top_n]
        ups = ups[shap_row[ups] > 0]

        downs = np.argsort(shap_row)[:top_n]
        downs = downs[shap_row[downs] < 0]

        explanations.append({
            "base_value": base_rounded,
            "top_positive": [_describe(shap_row, feature_row, j) for j in ups],
            "top_negative": [_describe(shap_row, feature_row, j) for j in downs],
        })

    return explanations
1113+
1114+
1115+
# Models to explain with SHAP (all 4 XGBoost/RF classifiers)
shap_targets = {
    "retention": (retention_model, X_full_retention, retention_features),
    "gateway_math": (gateway_math_model, X_full_gateway_math, gateway_math_features),
    "gateway_english": (gateway_english_model, X_full_gateway_english, gateway_english_features),
    "low_gpa": (low_gpa_model, X_gpa_clean, gpa_features),
}

# Build per-student dicts in a single pass, discarding each model's list promptly
student_shap_dicts = [dict() for _ in range(len(df))]

for model_name, (fitted_model, X_data, feature_list) in shap_targets.items():
    print(f"\nComputing SHAP explanations for {model_name} model...")
    explanations = compute_shap_explanations(fitted_model, X_data, feature_list)
    for holder, ex in zip(student_shap_dicts, explanations):
        holder[model_name] = ex
    print(f" ✓ {len(explanations)} student explanations generated")
    if explanations:
        ex = explanations[0]
        print(f" Sample (student 0): base_value={ex['base_value']}")
        for contrib in ex['top_positive'][:3]:
            print(f" ↑ {contrib['feature']}: +{contrib['shap_value']}")
        for contrib in ex['top_negative'][:3]:
            print(f" ↓ {contrib['feature']}: {contrib['shap_value']}")

df['shap_explanations'] = list(map(json.dumps, student_shap_dicts))
print(f"✓ SHAP explanations attached as JSON column ({len(df):,} students)")
1142+
10491143
# ============================================================================
10501144
# STEP 11: SAVE PREDICTIONS TO STUDENT-LEVEL FILE
10511145
# ============================================================================
@@ -1063,7 +1157,8 @@ def assign_alert_level(risk_score):
10631157
'prob_no_credential', 'prob_certificate', 'prob_associate', 'prob_bachelor',
10641158
'gateway_math_probability', 'gateway_math_prediction', 'gateway_math_risk',
10651159
'gateway_english_probability', 'gateway_english_prediction', 'gateway_english_risk',
1066-
'low_gpa_probability', 'low_gpa_prediction', 'academic_risk_level'
1160+
'low_gpa_probability', 'low_gpa_prediction', 'academic_risk_level',
1161+
'shap_explanations'
10671162
]
10681163

10691164
predictions_df = df[prediction_columns].copy()

ai_model/generate_readiness_scores.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -447,13 +447,17 @@ def score_student(row) -> dict:
447447
# LLM Enrichment (optional)
448448
# ============================================================================
449449

450-
def enrich_with_llm(record: dict, model: str) -> dict:
450+
def enrich_with_llm(record: dict, model: str, shap_data: dict = None) -> dict:
451451
"""
452452
Replace rationale and suggested_actions with LLM-generated content.
453453
Only called for medium/low readiness students.
454454
Input is the FERPA-safe profile — no PII sent to any external service.
455455
Returns the record with enriched text fields (score unchanged).
456456
457+
When shap_data is provided (from the ML pipeline's SHAP step), the prompt
458+
includes per-model feature attribution so the LLM can ground its
459+
explanation in what the models actually learned.
460+
457461
Provider is determined by the model string:
458462
"gpt-4o-mini" -> OpenAI (requires OPENAI_API_KEY)
459463
"ollama/llama3.2:3b" -> local Ollama (no key needed)
@@ -469,6 +473,25 @@ def enrich_with_llm(record: dict, model: str) -> dict:
469473
profile = json.loads(record["input_features"]) if isinstance(record["input_features"], str) else record["input_features"]
470474
risk_factors = json.loads(record["risk_factors"]) if isinstance(record["risk_factors"], str) else []
471475

476+
# Build SHAP context section if available
477+
shap_section = ""
478+
if shap_data:
479+
shap_lines = []
480+
for model_name, attrs in shap_data.items():
481+
shap_lines.append(f"\n {model_name} model (base prediction: {attrs.get('base_value', 'N/A')}):")
482+
for f in attrs.get("top_positive", []):
483+
shap_lines.append(f" ↑ {f['feature']} = {f['value']} (pushes prediction UP by {f['shap_value']})")
484+
for f in attrs.get("top_negative", []):
485+
shap_lines.append(f" ↓ {f['feature']} = {f['value']} (pushes prediction DOWN by {abs(f['shap_value'])})")
486+
shap_section = f"""
487+
488+
ML Model Feature Attribution (SHAP — shows which features drive each prediction):
489+
{''.join(shap_lines)}
490+
491+
IMPORTANT: Use these SHAP values to ground your explanation. Tell the advisor
492+
which specific factors are most responsible for this student's risk level,
493+
citing the magnitude. Do not speculate beyond what the models show."""
494+
472495
prompt = f"""You are an academic advisor assistant at Bishop State Community College.
473496
A student has a readiness score of {record['readiness_score']:.2f} ({record['readiness_level']} readiness).
474497
@@ -484,10 +507,10 @@ def enrich_with_llm(record: dict, model: str) -> dict:
484507
- Retention probability: {profile.get('retention_probability')}
485508
486509
Identified risk factors:
487-
{chr(10).join(f'- {f}' for f in risk_factors)}
510+
{chr(10).join(f'- {f}' for f in risk_factors)}{shap_section}
488511
489512
Write two things:
490-
1. RATIONALE: A 2-sentence explanation of this student's readiness score for an advisor.
513+
1. RATIONALE: A 2-3 sentence explanation of this student's readiness score for an advisor. If SHAP data is available, cite the top contributing factors by name and magnitude.
491514
2. ACTIONS: A JSON array of 3-5 specific, actionable intervention recommendations (strings only).
492515
493516
Format your response exactly as:
@@ -588,7 +611,15 @@ def main():
588611
record["generation_ms"] = elapsed_ms
589612
record["run_id"] = run_id
590613
if args.enrich_with_llm and record["readiness_level"] in ("medium", "low"):
591-
record = enrich_with_llm(record, args.llm_model)
614+
# Pass SHAP data if the shap_explanations column exists
615+
shap_data = None
616+
shap_raw = row.get("shap_explanations")
617+
if shap_raw and str(shap_raw) not in ("", "nan", "None"):
618+
try:
619+
shap_data = json.loads(shap_raw) if isinstance(shap_raw, str) else shap_raw
620+
except (json.JSONDecodeError, TypeError):
621+
pass
622+
record = enrich_with_llm(record, args.llm_model, shap_data=shap_data)
592623
records.append(record)
593624
except Exception as e:
594625
errors += 1

codebenders-dashboard/app/api/students/[guid]/route.ts

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@ import { type NextRequest, NextResponse } from "next/server"
22
import { getPool } from "@/lib/db"
33
import { canAccess, type Role } from "@/lib/roles"
44

5+
function safeParse<T>(raw: unknown, fallback: T): T {
6+
if (!raw) return fallback
7+
try {
8+
return typeof raw === "string" ? JSON.parse(raw) : (raw as T)
9+
} catch {
10+
return fallback
11+
}
12+
}
13+
514
export async function GET(
615
request: NextRequest,
716
{ params }: { params: Promise<{ guid: string }> }
@@ -29,6 +38,7 @@ export async function GET(
2938
ROUND((s.low_gpa_probability * 100)::numeric, 1) AS gpa_risk_pct,
3039
ROUND(s.predicted_time_to_credential::numeric, 1) AS time_to_credential,
3140
s.predicted_credential_label AS credential_type,
41+
s.shap_explanations,
3242
ROUND((r.readiness_score * 100)::numeric, 1) AS readiness_pct,
3343
r.readiness_level,
3444
r.rationale,
@@ -53,8 +63,9 @@ export async function GET(
5363
const row = result.rows[0]
5464
return NextResponse.json({
5565
...row,
56-
risk_factors: row.risk_factors ? JSON.parse(row.risk_factors) : [],
57-
suggested_actions: row.suggested_actions ? JSON.parse(row.suggested_actions) : [],
66+
risk_factors: safeParse(row.risk_factors, []),
67+
suggested_actions: safeParse(row.suggested_actions, []),
68+
shap_explanations: safeParse(row.shap_explanations, null),
5869
})
5970
} catch (error) {
6071
console.error("Student detail fetch error:", error)

0 commit comments

Comments
 (0)