Skip to content

Commit baff2b2

Browse files
committed
refactor: simplify SHAP computation and unify JSON parsing
- Remove all_contributions from SHAP output (built but never consumed)
- Use np.argsort for top-N selection instead of building full list + sort
- Merge JSON assembly into model loop (single pass, lower peak memory)
- Move shap/json imports to module top, drop _json alias
- Remove unused model_label parameter from compute_shap_explanations
- Remove unnecessary hasattr guards (X_data is always a DataFrame)
- Extract safeParse helper in student API for consistent JSON handling
1 parent dfc8d09 commit baff2b2

2 files changed

Lines changed: 47 additions & 67 deletions

File tree

ai_model/complete_ml_pipeline.py

Lines changed: 35 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
)
2323
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
2424
import xgboost as xgb
25+
import shap
26+
import json
2527
from datetime import datetime
2628
import warnings
2729
warnings.filterwarnings('ignore')
@@ -1053,66 +1055,58 @@ def assign_alert_level(risk_score):
10531055
print("STEP 10b: COMPUTING PER-STUDENT SHAP EXPLANATIONS")
10541056
print("=" * 80)
10551057

1056-
import shap
1057-
import json as _json
10581058

1059-
def compute_shap_explanations(model, X_data, feature_names, model_label, top_n=5):
1059+
def compute_shap_explanations(model, X_data, feature_names, top_n=5):
10601060
"""
10611061
Compute per-student SHAP values using TreeExplainer.
10621062
10631063
For binary classifiers, uses class-1 (positive outcome) SHAP values.
1064-
Returns top N positive/negative contributors per student plus the full
1065-
SHAP vector for downstream use by the fine-tuned explainer.
1064+
Returns top N positive/negative contributors per student.
10661065
"""
10671066
explainer = shap.TreeExplainer(model)
10681067
shap_values = explainer.shap_values(X_data)
10691068

1070-
# Binary classifiers: shap_values may be a list [class_0, class_1] (RandomForest),
1071-
# a 3D array (samples, features, classes), or a 2D array (XGBoost default).
1069+
# Binary classifiers: RandomForest returns list [class_0, class_1],
1070+
# some models return 3D (samples, features, classes), XGBoost returns 2D.
10721071
if isinstance(shap_values, list):
10731072
sv = shap_values[1]
10741073
elif shap_values.ndim == 3:
10751074
sv = shap_values[:, :, 1]
10761075
else:
10771076
sv = shap_values
10781077

1079-
# Base value — expected model output before any feature contributions
10801078
base = explainer.expected_value
10811079
if isinstance(base, (list, np.ndarray)):
10821080
base_value = float(base[1]) if len(base) > 1 else float(base[0])
10831081
else:
10841082
base_value = float(base)
10851083

1084+
base_rounded = round(base_value, 4)
1085+
1086+
def _make_entry(row_shap, row_values, j):
1087+
fval = row_values.iloc[j]
1088+
return {
1089+
"feature": feature_names[j],
1090+
"shap_value": round(float(row_shap[j]), 4),
1091+
"value": float(fval) if isinstance(fval, (int, float, np.integer, np.floating)) else str(fval),
1092+
}
1093+
10861094
explanations = []
10871095
for i in range(len(X_data)):
10881096
row_shap = sv[i]
1089-
row_values = X_data.iloc[i] if hasattr(X_data, 'iloc') else X_data[i]
1090-
1091-
# Build (feature_name, shap_value, feature_value) tuples
1092-
feature_contribs = []
1093-
for j, fname in enumerate(feature_names):
1094-
fval = row_values.iloc[j] if hasattr(row_values, 'iloc') else row_values[j]
1095-
feature_contribs.append({
1096-
"feature": fname,
1097-
"shap_value": round(float(row_shap[j]), 4),
1098-
"value": float(fval) if isinstance(fval, (int, float, np.integer, np.floating)) else str(fval),
1099-
})
1100-
1101-
sorted_pos = sorted(
1102-
[f for f in feature_contribs if f["shap_value"] > 0],
1103-
key=lambda x: x["shap_value"], reverse=True
1104-
)[:top_n]
1105-
1106-
sorted_neg = sorted(
1107-
[f for f in feature_contribs if f["shap_value"] < 0],
1108-
key=lambda x: x["shap_value"]
1109-
)[:top_n]
1097+
row_values = X_data.iloc[i]
1098+
1099+
# Use argsort to find top contributors without building a full list
1100+
pos_indices = np.argsort(-row_shap)[:top_n]
1101+
pos_indices = pos_indices[row_shap[pos_indices] > 0]
1102+
1103+
neg_indices = np.argsort(row_shap)[:top_n]
1104+
neg_indices = neg_indices[row_shap[neg_indices] < 0]
11101105

11111106
explanations.append({
1112-
"base_value": round(base_value, 4),
1113-
"top_positive": sorted_pos,
1114-
"top_negative": sorted_neg,
1115-
"all_contributions": feature_contribs,
1107+
"base_value": base_rounded,
1108+
"top_positive": [_make_entry(row_shap, row_values, j) for j in pos_indices],
1109+
"top_negative": [_make_entry(row_shap, row_values, j) for j in neg_indices],
11161110
})
11171111

11181112
return explanations
@@ -1126,11 +1120,14 @@ def compute_shap_explanations(model, X_data, feature_names, model_label, top_n=5
11261120
"low_gpa": (low_gpa_model, X_gpa_clean, gpa_features),
11271121
}
11281122

1129-
shap_results = {}
1123+
# Build per-student dicts in a single pass, discarding each model's list promptly
1124+
student_shap_dicts = [{} for _ in range(len(df))]
1125+
11301126
for label, (model, X_data, features) in shap_targets.items():
11311127
print(f"\nComputing SHAP explanations for {label} model...")
1132-
explanations = compute_shap_explanations(model, X_data, features, label)
1133-
shap_results[label] = explanations
1128+
explanations = compute_shap_explanations(model, X_data, features)
1129+
for i, ex in enumerate(explanations):
1130+
student_shap_dicts[i][label] = ex
11341131
print(f" ✓ {len(explanations)} student explanations generated")
11351132
if explanations:
11361133
ex = explanations[0]
@@ -1140,24 +1137,8 @@ def compute_shap_explanations(model, X_data, feature_names, model_label, top_n=5
11401137
for f in ex['top_negative'][:3]:
11411138
print(f" ↓ {f['feature']}: {f['shap_value']}")
11421139

1143-
# Attach SHAP explanations as JSON column on the main dataframe
1144-
# Stores only top contributors per model to keep DB size manageable
1145-
print("\nAttaching SHAP explanations to student dataframe...")
1146-
shap_json_col = []
1147-
for i in range(len(df)):
1148-
student_shap = {}
1149-
for label, explanations in shap_results.items():
1150-
if i < len(explanations):
1151-
ex = explanations[i]
1152-
student_shap[label] = {
1153-
"base_value": ex["base_value"],
1154-
"top_positive": ex["top_positive"],
1155-
"top_negative": ex["top_negative"],
1156-
}
1157-
shap_json_col.append(_json.dumps(student_shap))
1158-
1159-
df['shap_explanations'] = shap_json_col
1160-
print(f"✓ SHAP explanations attached as JSON column ({len(shap_json_col):,} students)")
1140+
df['shap_explanations'] = [json.dumps(d) for d in student_shap_dicts]
1141+
print(f"✓ SHAP explanations attached as JSON column ({len(df):,} students)")
11611142

11621143
# ============================================================================
11631144
# STEP 11: SAVE PREDICTIONS TO STUDENT-LEVEL FILE

codebenders-dashboard/app/api/students/[guid]/route.ts

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@ import { type NextRequest, NextResponse } from "next/server"
22
import { getPool } from "@/lib/db"
33
import { canAccess, type Role } from "@/lib/roles"
44

5+
function safeParse<T>(raw: unknown, fallback: T): T {
6+
if (!raw) return fallback
7+
try {
8+
return typeof raw === "string" ? JSON.parse(raw) : (raw as T)
9+
} catch {
10+
return fallback
11+
}
12+
}
13+
514
export async function GET(
615
request: NextRequest,
716
{ params }: { params: Promise<{ guid: string }> }
@@ -52,21 +61,11 @@ export async function GET(
5261
}
5362

5463
const row = result.rows[0]
55-
// Parse JSON string columns into objects for the frontend
56-
let shap = null
57-
if (row.shap_explanations) {
58-
try {
59-
shap = typeof row.shap_explanations === "string"
60-
? JSON.parse(row.shap_explanations)
61-
: row.shap_explanations
62-
} catch { shap = null }
63-
}
64-
6564
return NextResponse.json({
6665
...row,
67-
risk_factors: row.risk_factors ? JSON.parse(row.risk_factors) : [],
68-
suggested_actions: row.suggested_actions ? JSON.parse(row.suggested_actions) : [],
69-
shap_explanations: shap,
66+
risk_factors: safeParse(row.risk_factors, []),
67+
suggested_actions: safeParse(row.suggested_actions, []),
68+
shap_explanations: safeParse(row.shap_explanations, null),
7069
})
7170
} catch (error) {
7271
console.error("Student detail fetch error:", error)

0 commit comments

Comments (0)