# visualization.py - 客户流失预测模型可视化(直接运行即可) import joblib import matplotlib.pyplot as plt import seaborn as sns import pandas as pd from sklearn.metrics import confusion_matrix, roc_curve, auc # -------------------------- 基础设置(解决中文显示、图表样式)-------------------------- plt.rcParams['font.sans-serif'] = ['SimHei'] # 中文显示(Windows) # 如果是Mac/Linux,替换为:plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题 plt.style.use('seaborn-v0_8-whitegrid') # 图表样式(清爽易读) # -------------------------- 1. 加载模型和数据(复用已有逻辑)-------------------------- def load_model_and_data(): """加载训练好的模型和测试集数据""" # 加载模型(确保模型文件路径正确) try: model = joblib.load('telco_churn_model.pkl') print("✅ 模型加载成功") except FileNotFoundError: raise FileNotFoundError("❌ 未找到模型文件!请先运行 src/model.py 训练模型") # 加载并切分数据(复用 src/data.py 的逻辑,避免重复代码) try: from src.data import load_data, split_data df = load_data() X_train, X_test, y_train, y_test = split_data(df) print("✅ 测试集数据加载成功(共1409条)") return model, X_test, y_test except ImportError: raise ImportError("❌ 未找到 src/data.py!请确保项目目录结构正确") # -------------------------- 2. 特征重要性TOP10可视化(核心业务洞察)-------------------------- def plot_feature_importance(model): """绘制特征重要性TOP10图表""" # 提取预处理后的特征名和重要性得分 preprocessor = model.named_steps['preprocessor'] feature_names = preprocessor.get_feature_names_out() feature_importance = model.named_steps['classifier'].feature_importances_ # 整理数据(排序+取TOP10,简化特征名方便显示) feature_df = pd.DataFrame({ '特征名': feature_names, '重要性': feature_importance }).sort_values('重要性', ascending=False).head(10) # 简化特征名(原特征名太长,图表显示优化) feature_name_map = { 'tenure': '客户在网时长', 'TotalCharges': '总消费金额', 'MonthlyCharges': '月消费金额', 'Contract_Two year': '合约期-2年', 'InternetService_Fiber optic': '网络类型-光纤', 'PaymentMethod_Electronic check': '支付方式-电子支票', 'Contract_One year': '合约期-1年', 'OnlineSecurity_Yes': '在线安全服务-有', 'TechSupport_Yes': '技术支持-有', 'PaperlessBilling_Yes': '电子账单-是' } feature_df['简化特征名'] = feature_df['特征名'].map(lambda x: feature_name_map.get(x, x[:15])) # 兜底避免报错 # 绘制水平条形图(更易读) fig, ax = plt.subplots(figsize=(10, 6)) sns.barplot( x='重要性', y='简化特征名', data=feature_df, palette='viridis_r', ax=ax # 颜色渐变(反向,重要性越高颜色越深) ) # 图表美化(标题、标签、数值标注) ax.set_title('客户流失预测 - 特征重要性TOP10', fontsize=16, fontweight='bold', pad=20) ax.set_xlabel('重要性得分', fontsize=12) ax.set_ylabel('特征', fontsize=12) ax.tick_params(axis='y', labelsize=10) # 在条形图上添加数值(直观展示得分) for i, v in enumerate(feature_df['重要性']): ax.text(v + 0.002, i, f'{v:.3f}', va='center', fontsize=9) # 保存图表(高清,可直接插入PPT) plt.tight_layout() plt.savefig('特征重要性TOP10.png', dpi=300, bbox_inches='tight') print("✅ 特征重要性图表已保存为:特征重要性TOP10.png") # -------------------------- 3. 混淆矩阵可视化(模型效果直观展示)-------------------------- def plot_confusion_matrix(model, X_test, y_test): """绘制混淆矩阵(展示模型预测准确率、漏判/误判情况)""" # 生成预测结果 y_pred = model.predict(X_test) # 计算混淆矩阵 cm = confusion_matrix(y_test, y_pred) # 混淆矩阵标签(0=未流失,1=流失) labels = ['未流失', '流失'] # 绘制热力图 fig, ax = plt.subplots(figsize=(8, 6)) sns.heatmap( cm, annot=True, fmt='d', cmap='Blues', # fmt='d' 显示整数 xticklabels=labels, yticklabels=labels, ax=ax, cbar_kws={'label': '客户数量'} # 颜色条标签 ) # 图表美化 ax.set_title('客户流失预测 - 混淆矩阵', fontsize=16, fontweight='bold', pad=20) ax.set_xlabel('预测标签', fontsize=12) ax.set_ylabel('真实标签', fontsize=12) # 添加统计信息(准确率、流失识别率) total = cm.sum() accuracy = (cm[0,0] + cm[1,1]) / total recall_churn = cm[1,1] / (cm[1,0] + cm[1,1]) # 流失客户识别率 ax.text(0.5, -0.15, f'准确率:{accuracy:.3f} | 流失识别率:{recall_churn:.3f}', ha='center', transform=ax.transAxes, fontsize=11) # 保存图表 plt.tight_layout() plt.savefig('混淆矩阵.png', dpi=300, bbox_inches='tight') print("✅ 混淆矩阵图表已保存为:混淆矩阵.png") # -------------------------- 4. 可选:ROC曲线可视化(进阶模型评估)-------------------------- def plot_roc_curve(model, X_test, y_test): """绘制ROC曲线(展示模型区分能力,AUC值)""" # 生成预测概率 y_pred_proba = model.predict_proba(X_test)[:, 1] # 取流失(1类)的概率 fpr, tpr, _ = roc_curve(y_test, y_pred_proba) roc_auc = auc(fpr, tpr) # 绘制ROC曲线 fig, ax = plt.subplots(figsize=(8, 6)) ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC曲线 (AUC = {roc_auc:.3f})') ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='随机猜测') # 图表美化 ax.set_xlim([0.0, 1.0]) ax.set_ylim([0.0, 1.05]) ax.set_title('客户流失预测 - ROC曲线', fontsize=16, fontweight='bold', pad=20) ax.set_xlabel('假阳性率(误判为流失)', fontsize=12) ax.set_ylabel('真阳性率(正确识别流失)', fontsize=12) ax.legend(loc="lower right", fontsize=11) ax.grid(True, alpha=0.3) # 保存图表 plt.tight_layout() plt.savefig('ROC曲线.png', dpi=300, bbox_inches='tight') print("✅ ROC曲线图表已保存为:ROC曲线.png") # -------------------------- 主函数(一键运行所有可视化)-------------------------- if __name__ == "__main__": print("🚀 开始生成可视化图表...") try: # 加载模型和数据 model, X_test, y_test = load_model_and_data() # 生成3张图表(特征重要性 + 混淆矩阵 + ROC曲线) plot_feature_importance(model) plot_confusion_matrix(model, X_test, y_test) plot_roc_curve(model, X_test, y_test) print("\n🎉 所有图表生成完成!文件保存在项目根目录:") print("1. 特征重要性TOP10.png(业务洞察核心)") print("2. 混淆矩阵.png(模型效果直观展示)") print("3. ROC曲线.png(进阶评估,AUC值)") except Exception as e: print(f"\n❌ 生成失败:{str(e)}")