# %%
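# Binary CN classification with CatBoost: min-max scaling, SMOTE resampling of the
# training set, a custom decision threshold, confusion-matrix metrics, and SHAP analysis.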
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import confusion_matrix, roc_auc_score
from sklearn.ensemble import GradientBoostingClassifier
from imblearn.over_sampling import SMOTE
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from catboost import CatBoostClassifier
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
import shap

plt.rcParams['font.family'] = 'Times New Roman'

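# Load the dataset; eight columns are used as input features and the CN target is
# binarized: class 'A' for CN <= 40, class 'B' otherwise.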
data = pd.read_csv("data1.csv")

X = data[['EC', 'AV', 'FR', 'SSA', 'Vpore', 'ID/IG', 'N', 'O']]
data['CN_class'] = data['CN'].apply(lambda x: 'A' if x <= 40 else 'B')
y = data['CN_class']

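# Scale all features to [0, 1]. Note that the scaler is fit on the full dataset;
# fitting it on the training split only would avoid using test-set ranges.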
scaler = MinMaxScaler()
X_normalized = scaler.fit_transform(X)

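# Stratified 80/20 train/test split so both classes keep their original proportions.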
X_train, X_test, y_train, y_test = train_test_split(
    X_normalized, y, test_size=0.2, stratify=y, random_state=42
)

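# Balance the two classes in the training set with SMOTE; the test set is left untouched.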
smote = SMOTE(random_state=42)
X_train_resampled, y_train_resampled = smote.fit_resample(X_train, y_train)

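# Gradient-boosted trees via CatBoost with shallow depth and L2 leaf regularization.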
model = CatBoostClassifier(
    depth=5,
    iterations=200,
    l2_leaf_reg=3,
    learning_rate=0.1,
    verbose=0  # suppress per-iteration training output
)

model.fit(X_train_resampled, y_train_resampled)

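# Predicted probability of class 'B' (column 1 of predict_proba, assuming
# model.classes_ is ordered ['A', 'B']).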
y_pred_proba_train = model.predict_proba(X_train)[:, 1]
y_pred_proba_test = model.predict_proba(X_test)[:, 1]

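# Map P(B) to class labels with an adjustable decision threshold; a threshold of 0.6
# (instead of the default 0.5) is applied below.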
def predict_with_threshold(proba, threshold=0.5):
    return np.where(proba >= threshold, 'B', 'A')

y_pred_train = predict_with_threshold(y_pred_proba_train, 0.6)
y_pred_test = predict_with_threshold(y_pred_proba_test, 0.6)

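# Confusion matrices with rows = actual and columns = predicted, ordered ['A', 'B'].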
conf_train = confusion_matrix(y_train, y_pred_train, labels=['A', 'B'])
conf_test = confusion_matrix(y_test, y_pred_test, labels=['A', 'B'])

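# Plot a confusion matrix as an annotated heatmap.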
def plot_cm(cm, title, fontsize=18, annot_fontsize=18):
    plt.figure(figsize=(6, 5))
    sns.heatmap(
        cm, annot=True, fmt='d', cmap='Blues',
        xticklabels=['A', 'B'], yticklabels=['A', 'B'],
        annot_kws={"size": annot_fontsize}
    )
    plt.title(title, fontsize=fontsize)
    plt.xlabel('Predicted', fontsize=16)
    plt.ylabel('Actual', fontsize=16)
    plt.show()

plot_cm(conf_train, "Train Confusion Matrix")
plot_cm(conf_test, "Test Confusion Matrix")

def calc_tpr_fpr(cm):
    # Treats class 'A' (CN <= 40) as the positive class; cm rows/columns are
    # ordered ['A', 'B'] with rows = actual and columns = predicted.
    TP, FN = cm[0, 0], cm[0, 1]
    FP, TN = cm[1, 0], cm[1, 1]
    TPR = TP / (TP + FN) if TP + FN > 0 else 0
    FPR = FP / (FP + TN) if FP + TN > 0 else 0
    return TPR, FPR

tpr_train, fpr_train = calc_tpr_fpr(conf_train)
tpr_test, fpr_test = calc_tpr_fpr(conf_test)

print(f"Train TPR: {tpr_train:.2f}, FPR: {fpr_train:.2f}")
print(f"Test TPR: {tpr_test:.2f}, FPR: {fpr_test:.2f}")
print(f"AUC Score: {roc_auc_score(y_test, y_pred_proba_test):.2f}")

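# SHAP values from TreeExplainer give each feature's contribution (in log-odds) toward
# predicting the positive class 'B' for every training sample.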
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_train)

shap.summary_plot(
    shap_values, X_train,
    feature_names=['EC', 'AV', 'FR', 'SSA', 'Vpore', 'ID/IG', 'N', 'O'],
    show=False
)

ax = plt.gca()

ax.tick_params(axis='x', labelsize=18) 
ax.tick_params(axis='y', labelsize=18) 

plt.show()

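# Global feature importance: mean absolute SHAP value per feature, shown as a
# horizontal bar chart with the most influential feature at the top.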
mean_shap_values = np.mean(np.abs(shap_values), axis=0)

feature_names = ['EC', 'AV', 'FR', 'SSA', 'Vpore', 'ID/IG', 'N', 'O']

shap_feature_importance = sorted(zip(feature_names, mean_shap_values), key=lambda x: x[1], reverse=True)

sorted_features = [item[0] for item in shap_feature_importance]
sorted_shap_values = [item[1] for item in shap_feature_importance]

plt.figure(figsize=(8, 6))
plt.barh(sorted_features, sorted_shap_values, color='skyblue')
plt.xlabel('Average SHAP Value', fontsize=24)
plt.xticks(fontsize=20)  
plt.yticks(fontsize=20)  
plt.gca().invert_yaxis()  
plt.show()