Skip to content

Commit ac0dc24

Browse files
committed
fix: cast np.float64 metrics to float for psycopg2, force labels=[0,1] in confusion matrix
1 parent 4a7a1c4 commit ac0dc24

2 files changed

Lines changed: 14 additions & 11 deletions

File tree

ai_model/complete_ml_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,7 @@ def assign_alert_level(risk_score):
783783
print(f"F1-Score: {math_f1:.4f}")
784784

785785
print("\nConfusion Matrix:")
786-
cm = confusion_matrix(y_test, y_pred)
786+
cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
787787
print(f" Predicted")
788788
print(f" No Pass Pass")
789789
print(f"Actual No {cm[0,0]:6d} {cm[0,1]:6d}")
@@ -894,7 +894,7 @@ def assign_alert_level(risk_score):
894894
print(f"F1-Score: {english_f1:.4f}")
895895

896896
print("\nConfusion Matrix:")
897-
cm = confusion_matrix(y_test, y_pred)
897+
cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
898898
print(f" Predicted")
899899
print(f" No Pass Pass")
900900
print(f"Actual No {cm[0,0]:6d} {cm[0,1]:6d}")
@@ -1008,7 +1008,7 @@ def assign_alert_level(risk_score):
10081008
print(f"F1-Score: {gpa_f1:.4f}")
10091009

10101010
print("\nConfusion Matrix:")
1011-
cm = confusion_matrix(y_test, y_pred)
1011+
cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
10121012
print(f" Predicted")
10131013
print(f" GPA>=2.0 GPA<2.0")
10141014
print(f"Actual >=2.0 {cm[0,0]:6d} {cm[0,1]:6d}")

operations/db_utils.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,17 +191,20 @@ def save_model_performance(model_name, model_type, metrics, notes=""):
191191
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
192192
"""
193193

194+
def _f(v):
195+
return float(v) if v is not None else None
196+
194197
values = (
195198
model_name,
196199
model_type,
197-
metrics.get('accuracy'),
198-
metrics.get('precision'),
199-
metrics.get('recall'),
200-
metrics.get('f1'),
201-
metrics.get('auc_roc'),
202-
metrics.get('rmse'),
203-
metrics.get('mae'),
204-
metrics.get('r2_score'),
200+
_f(metrics.get('accuracy')),
201+
_f(metrics.get('precision')),
202+
_f(metrics.get('recall')),
203+
_f(metrics.get('f1')),
204+
_f(metrics.get('auc_roc')),
205+
_f(metrics.get('rmse')),
206+
_f(metrics.get('mae')),
207+
_f(metrics.get('r2_score')),
205208
notes
206209
)
207210

0 commit comments

Comments
 (0)