2222)
2323from sklearn .ensemble import RandomForestClassifier , RandomForestRegressor
2424import xgboost as xgb
25+ import shap
26+ import json
2527from datetime import datetime
2628import warnings
2729warnings .filterwarnings ('ignore' )
@@ -1046,6 +1048,98 @@ def assign_alert_level(risk_score):
10461048
print("Low GPA predictions generated")

# ============================================================================
# STEP 10b: PER-STUDENT SHAP EXPLANATIONS
# ============================================================================
# Banner marking the start of the SHAP explanation step.
divider = "=" * 80
print("\n " + divider)
print("STEP 10b: COMPUTING PER-STUDENT SHAP EXPLANATIONS")
print(divider)
1058+
def compute_shap_explanations(model, X_data, feature_names, top_n=5):
    """
    Build per-row SHAP attributions for a fitted tree-based model.

    For binary classifiers the class-1 (positive outcome) attributions are
    used. Each returned dict carries the explainer's base value plus the
    top_n strongest positive and top_n strongest negative feature
    contributions for one student, with SHAP values rounded to 4 decimals.
    """
    explainer = shap.TreeExplainer(model)
    raw = explainer.shap_values(X_data)

    # Normalize the shapes SHAP may return to a 2D (rows, features) matrix
    # of class-1 attributions: RandomForest yields a [class_0, class_1]
    # list, some models a 3D array, XGBoost a plain 2D array.
    if isinstance(raw, list):
        attributions = raw[1]
    else:
        attributions = raw[:, :, 1] if raw.ndim == 3 else raw

    expected = explainer.expected_value
    if isinstance(expected, (list, np.ndarray)):
        base = float(expected[1] if len(expected) > 1 else expected[0])
    else:
        base = float(expected)
    base = round(base, 4)

    def _entry(contribs, feature_row, idx):
        # One feature's contribution record; non-numeric values are
        # stringified so the result stays JSON-serializable.
        raw_val = feature_row.iloc[idx]
        numeric = isinstance(raw_val, (int, float, np.integer, np.floating))
        return {
            "feature": feature_names[idx],
            "shap_value": round(float(contribs[idx]), 4),
            "value": float(raw_val) if numeric else str(raw_val),
        }

    results = []
    for row_idx in range(len(X_data)):
        contribs = attributions[row_idx]
        feature_row = X_data.iloc[row_idx]

        # Rank features by contribution, then keep only genuinely
        # positive / negative ones among the top_n in each direction.
        best = np.argsort(-contribs)[:top_n]
        best = best[contribs[best] > 0]

        worst = np.argsort(contribs)[:top_n]
        worst = worst[contribs[worst] < 0]

        results.append({
            "base_value": base,
            "top_positive": [_entry(contribs, feature_row, j) for j in best],
            "top_negative": [_entry(contribs, feature_row, j) for j in worst],
        })

    return results
1113+
1114+
# Models to explain with SHAP (all 4 XGBoost/RF classifiers)
shap_targets = {
    "retention": (retention_model, X_full_retention, retention_features),
    "gateway_math": (gateway_math_model, X_full_gateway_math, gateway_math_features),
    "gateway_english": (gateway_english_model, X_full_gateway_english, gateway_english_features),
    "low_gpa": (low_gpa_model, X_gpa_clean, gpa_features),
}

# Build per-student dicts in a single pass, discarding each model's list promptly
student_shap_dicts = [{} for _ in range(len(df))]

for label, (clf, features_df, feature_list) in shap_targets.items():
    print(f"\n Computing SHAP explanations for {label} model...")
    per_student = compute_shap_explanations(clf, features_df, feature_list)
    for idx, explanation in enumerate(per_student):
        student_shap_dicts[idx][label] = explanation
    print(f" ✓ {len(per_student)} student explanations generated")

    # Show a quick sample for the first student so the run log is auditable.
    if per_student:
        sample = per_student[0]
        print(f" Sample (student 0): base_value={sample['base_value']}")
        for contrib in sample['top_positive'][:3]:
            print(f" ↑ {contrib['feature']}: +{contrib['shap_value']}")
        for contrib in sample['top_negative'][:3]:
            print(f" ↓ {contrib['feature']}: {contrib['shap_value']}")

df['shap_explanations'] = [json.dumps(d) for d in student_shap_dicts]
print(f"✓ SHAP explanations attached as JSON column ({len(df):,} students)")
1142+
10491143# ============================================================================
10501144# STEP 11: SAVE PREDICTIONS TO STUDENT-LEVEL FILE
10511145# ============================================================================
@@ -1063,7 +1157,8 @@ def assign_alert_level(risk_score):
10631157 'prob_no_credential' , 'prob_certificate' , 'prob_associate' , 'prob_bachelor' ,
10641158 'gateway_math_probability' , 'gateway_math_prediction' , 'gateway_math_risk' ,
10651159 'gateway_english_probability' , 'gateway_english_prediction' , 'gateway_english_risk' ,
1066- 'low_gpa_probability' , 'low_gpa_prediction' , 'academic_risk_level'
1160+ 'low_gpa_probability' , 'low_gpa_prediction' , 'academic_risk_level' ,
1161+ 'shap_explanations'
10671162]
10681163
10691164predictions_df = df [prediction_columns ].copy ()
0 commit comments