telco-customer-churn-predic.../visualization.py

165 lines
7.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# visualization.py - 客户流失预测模型可视化(直接运行即可)
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import confusion_matrix, roc_curve, auc
# -------------------------- 基础设置(解决中文显示、图表样式)--------------------------
plt.rcParams['font.sans-serif'] = ['SimHei'] # 中文显示Windows
# 如果是Mac/Linux替换为plt.rcParams['font.sans-serif'] = ['Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
plt.style.use('seaborn-v0_8-whitegrid') # 图表样式(清爽易读)
# -------------------------- 1. 加载模型和数据(复用已有逻辑)--------------------------
def load_model_and_data():
"""加载训练好的模型和测试集数据"""
# 加载模型(确保模型文件路径正确)
try:
model = joblib.load('telco_churn_model.pkl')
print("✅ 模型加载成功")
except FileNotFoundError:
raise FileNotFoundError("❌ 未找到模型文件!请先运行 src/model.py 训练模型")
# 加载并切分数据(复用 src/data.py 的逻辑,避免重复代码)
try:
from src.data import load_data, split_data
df = load_data()
X_train, X_test, y_train, y_test = split_data(df)
print("✅ 测试集数据加载成功共1409条")
return model, X_test, y_test
except ImportError:
raise ImportError("❌ 未找到 src/data.py请确保项目目录结构正确")
# -------------------------- 2. 特征重要性TOP10可视化核心业务洞察--------------------------
def plot_feature_importance(model):
"""绘制特征重要性TOP10图表"""
# 提取预处理后的特征名和重要性得分
preprocessor = model.named_steps['preprocessor']
feature_names = preprocessor.get_feature_names_out()
feature_importance = model.named_steps['classifier'].feature_importances_
# 整理数据(排序+取TOP10简化特征名方便显示
feature_df = pd.DataFrame({
'特征名': feature_names,
'重要性': feature_importance
}).sort_values('重要性', ascending=False).head(10)
# 简化特征名(原特征名太长,图表显示优化)
feature_name_map = {
'tenure': '客户在网时长',
'TotalCharges': '总消费金额',
'MonthlyCharges': '月消费金额',
'Contract_Two year': '合约期-2年',
'InternetService_Fiber optic': '网络类型-光纤',
'PaymentMethod_Electronic check': '支付方式-电子支票',
'Contract_One year': '合约期-1年',
'OnlineSecurity_Yes': '在线安全服务-有',
'TechSupport_Yes': '技术支持-有',
'PaperlessBilling_Yes': '电子账单-是'
}
feature_df['简化特征名'] = feature_df['特征名'].map(lambda x: feature_name_map.get(x, x[:15])) # 兜底避免报错
# 绘制水平条形图(更易读)
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(
x='重要性', y='简化特征名', data=feature_df,
palette='viridis_r', ax=ax # 颜色渐变(反向,重要性越高颜色越深)
)
# 图表美化(标题、标签、数值标注)
ax.set_title('客户流失预测 - 特征重要性TOP10', fontsize=16, fontweight='bold', pad=20)
ax.set_xlabel('重要性得分', fontsize=12)
ax.set_ylabel('特征', fontsize=12)
ax.tick_params(axis='y', labelsize=10)
# 在条形图上添加数值(直观展示得分)
for i, v in enumerate(feature_df['重要性']):
ax.text(v + 0.002, i, f'{v:.3f}', va='center', fontsize=9)
# 保存图表高清可直接插入PPT
plt.tight_layout()
plt.savefig('特征重要性TOP10.png', dpi=300, bbox_inches='tight')
print("✅ 特征重要性图表已保存为特征重要性TOP10.png")
# -------------------------- 3. 混淆矩阵可视化(模型效果直观展示)--------------------------
def plot_confusion_matrix(model, X_test, y_test):
"""绘制混淆矩阵(展示模型预测准确率、漏判/误判情况)"""
# 生成预测结果
y_pred = model.predict(X_test)
# 计算混淆矩阵
cm = confusion_matrix(y_test, y_pred)
# 混淆矩阵标签0=未流失1=流失)
labels = ['未流失', '流失']
# 绘制热力图
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
cm, annot=True, fmt='d', cmap='Blues', # fmt='d' 显示整数
xticklabels=labels, yticklabels=labels, ax=ax,
cbar_kws={'label': '客户数量'} # 颜色条标签
)
# 图表美化
ax.set_title('客户流失预测 - 混淆矩阵', fontsize=16, fontweight='bold', pad=20)
ax.set_xlabel('预测标签', fontsize=12)
ax.set_ylabel('真实标签', fontsize=12)
# 添加统计信息(准确率、流失识别率)
total = cm.sum()
accuracy = (cm[0,0] + cm[1,1]) / total
recall_churn = cm[1,1] / (cm[1,0] + cm[1,1]) # 流失客户识别率
ax.text(0.5, -0.15, f'准确率:{accuracy:.3f} | 流失识别率:{recall_churn:.3f}',
ha='center', transform=ax.transAxes, fontsize=11)
# 保存图表
plt.tight_layout()
plt.savefig('混淆矩阵.png', dpi=300, bbox_inches='tight')
print("✅ 混淆矩阵图表已保存为:混淆矩阵.png")
# -------------------------- 4. 可选ROC曲线可视化进阶模型评估--------------------------
def plot_roc_curve(model, X_test, y_test):
"""绘制ROC曲线展示模型区分能力AUC值"""
# 生成预测概率
y_pred_proba = model.predict_proba(X_test)[:, 1] # 取流失1类的概率
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
roc_auc = auc(fpr, tpr)
# 绘制ROC曲线
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC曲线 (AUC = {roc_auc:.3f})')
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='随机猜测')
# 图表美化
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_title('客户流失预测 - ROC曲线', fontsize=16, fontweight='bold', pad=20)
ax.set_xlabel('假阳性率(误判为流失)', fontsize=12)
ax.set_ylabel('真阳性率(正确识别流失)', fontsize=12)
ax.legend(loc="lower right", fontsize=11)
ax.grid(True, alpha=0.3)
# 保存图表
plt.tight_layout()
plt.savefig('ROC曲线.png', dpi=300, bbox_inches='tight')
print("✅ ROC曲线图表已保存为ROC曲线.png")
# -------------------------- 主函数(一键运行所有可视化)--------------------------
if __name__ == "__main__":
print("🚀 开始生成可视化图表...")
try:
# 加载模型和数据
model, X_test, y_test = load_model_and_data()
# 生成3张图表特征重要性 + 混淆矩阵 + ROC曲线
plot_feature_importance(model)
plot_confusion_matrix(model, X_test, y_test)
plot_roc_curve(model, X_test, y_test)
print("\n🎉 所有图表生成完成!文件保存在项目根目录:")
print("1. 特征重要性TOP10.png业务洞察核心")
print("2. 混淆矩阵.png模型效果直观展示")
print("3. ROC曲线.png进阶评估AUC值")
except Exception as e:
print(f"\n❌ 生成失败:{str(e)}")