2222)
2323from sklearn .ensemble import RandomForestClassifier , RandomForestRegressor
2424import xgboost as xgb
25+ import shap
26+ import json
2527from datetime import datetime
2628import warnings
2729warnings .filterwarnings ('ignore' )
@@ -1053,66 +1055,58 @@ def assign_alert_level(risk_score):
# Banner announcing the per-student SHAP explanation step.
print("STEP 10b: COMPUTING PER-STUDENT SHAP EXPLANATIONS", "=" * 80, sep="\n")
def compute_shap_explanations(model, X_data, feature_names, top_n=5):
    """
    Compute per-student SHAP values using TreeExplainer.

    For binary classifiers, uses class-1 (positive outcome) SHAP values.
    Returns top N positive/negative contributors per student.

    Parameters
    ----------
    model : fitted tree-based estimator (e.g. RandomForest*, XGBoost).
    X_data : pandas.DataFrame of features, one row per student.
    feature_names : sequence of names aligned with X_data's columns.
    top_n : int, max positive / negative contributors kept per student.

    Returns
    -------
    list[dict]
        One dict per row of X_data with keys "base_value",
        "top_positive" and "top_negative"; all values are
        JSON-serializable (floats rounded to 4 decimals).
    """
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_data)

    # Binary classifiers: RandomForest returns list [class_0, class_1],
    # some models return 3D (samples, features, classes), XGBoost returns 2D.
    if isinstance(shap_values, list):
        sv = shap_values[1]
    elif shap_values.ndim == 3:
        sv = shap_values[:, :, 1]
    else:
        sv = shap_values

    # Expected model output before any feature contributions. Depending on
    # SHAP version / model this may be a Python scalar, a 0-d ndarray, or a
    # per-class list/array. np.atleast_1d normalizes all three; the previous
    # len(base) check raised TypeError on a 0-d ndarray.
    base = np.atleast_1d(explainer.expected_value)
    base_value = float(base[1]) if len(base) > 1 else float(base[0])
    base_rounded = round(base_value, 4)

    def _make_entry(row_shap, row_values, j):
        # One (feature, shap, value) record; non-numeric feature values are
        # stringified so the result stays JSON-serializable.
        fval = row_values.iloc[j]
        return {
            "feature": feature_names[j],
            "shap_value": round(float(row_shap[j]), 4),
            "value": float(fval) if isinstance(fval, (int, float, np.integer, np.floating)) else str(fval),
        }

    explanations = []
    for i in range(len(X_data)):
        row_shap = sv[i]
        row_values = X_data.iloc[i]

        # Sort once per row: positives come from the descending end of the
        # order, negatives from the ascending end (previously two argsorts).
        order = np.argsort(row_shap)
        pos_indices = order[::-1][:top_n]
        pos_indices = pos_indices[row_shap[pos_indices] > 0]
        neg_indices = order[:top_n]
        neg_indices = neg_indices[row_shap[neg_indices] < 0]

        explanations.append({
            "base_value": base_rounded,
            "top_positive": [_make_entry(row_shap, row_values, j) for j in pos_indices],
            "top_negative": [_make_entry(row_shap, row_values, j) for j in neg_indices],
        })

    return explanations
@@ -1126,11 +1120,14 @@ def compute_shap_explanations(model, X_data, feature_names, model_label, top_n=5
11261120 "low_gpa" : (low_gpa_model , X_gpa_clean , gpa_features ),
11271121}
11281122
# Pre-allocate one empty dict per student; each model's explanations are
# merged into these below, letting every per-model list be dropped promptly.
student_shap_dicts = [{} for _ in df.index]
11301126for label , (model , X_data , features ) in shap_targets .items ():
11311127 print (f"\n Computing SHAP explanations for { label } model..." )
1132- explanations = compute_shap_explanations (model , X_data , features , label )
1133- shap_results [label ] = explanations
1128+ explanations = compute_shap_explanations (model , X_data , features )
1129+ for i , ex in enumerate (explanations ):
1130+ student_shap_dicts [i ][label ] = ex
11341131 print (f" ✓ { len (explanations )} student explanations generated" )
11351132 if explanations :
11361133 ex = explanations [0 ]
@@ -1140,24 +1137,8 @@ def compute_shap_explanations(model, X_data, feature_names, model_label, top_n=5
11401137 for f in ex ['top_negative' ][:3 ]:
11411138 print (f" ↓ { f ['feature' ]} : { f ['shap_value' ]} " )
11421139
# Serialize each student's per-model SHAP summary into a JSON string column.
df['shap_explanations'] = list(map(json.dumps, student_shap_dicts))
print(f"✓ SHAP explanations attached as JSON column ({len(df):,} students)")
11611142
11621143# ============================================================================
11631144# STEP 11: SAVE PREDICTIONS TO STUDENT-LEVEL FILE
0 commit comments