generated from Python-2026Spring/assignment-05-final-project-template
feat: 完成项目全流程配置(含Streamlit演示、uv依赖、.gitignore) #1
24
.gitignore
vendored
24
.gitignore
vendored
@ -1,4 +1,26 @@
|
|||||||
|
# ===== 环境变量(绝对不能提交!)=====
|
||||||
|
.env
|
||||||
|
|
||||||
|
# ===== Python 虚拟环境 =====
|
||||||
.venv/
|
.venv/
|
||||||
|
venv/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.pyc
|
*.pyc
|
||||||
results/
|
*.pyo
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# ===== IDE 配置 =====
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
|
||||||
|
# ===== macOS 系统文件 =====
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# ===== Jupyter =====
|
||||||
|
.ipynb_checkpoints/
|
||||||
|
|
||||||
|
# ===== 超大文件(超过 10MB 需手动添加)=====
|
||||||
|
# 如果你的数据或模型文件超过 10MB,请在下面添加:
|
||||||
|
# data/large_dataset.csv
|
||||||
|
# models/large_model.pkl
|
||||||
1
.python-version
Normal file
1
.python-version
Normal file
@ -0,0 +1 @@
|
|||||||
|
3.13
|
||||||
80
app.py
Normal file
80
app.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
# app.py - 电信客户流失预测Streamlit交互页面
|
||||||
|
import streamlit as st
|
||||||
|
import joblib
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
|
||||||
|
# Base plotting setup (Chinese text rendering).
# Matplotlib walks this list in order and uses the first family that is
# installed, so a Windows-only entry is followed by common fallbacks instead
# of requiring a manual edit on other machines.
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'Arial Unicode MS']
# Render the minus sign with the ASCII hyphen so it does not show up as a
# missing-glyph box when a CJK font is active.
plt.rcParams['axes.unicode_minus'] = False
|
||||||
|
|
||||||
|
# Cached model loading
@st.cache_resource
def _load_model_from_disk():
    """Read the persisted pipeline from ``telco_churn_model.pkl``.

    Only a successful load is cached: when the file is missing,
    ``joblib.load`` raises and nothing is stored, so the next call retries.
    """
    return joblib.load('telco_churn_model.pkl')


def load_trained_model():
    """Return the trained model, or ``None`` when the model file is missing.

    The failure branch lives outside the cached helper on purpose. Returning
    ``None`` from inside ``@st.cache_resource`` would cache that ``None`` and
    keep reporting a missing model even after the file has been generated.

    Returns:
        The object stored in ``telco_churn_model.pkl``, or ``None`` after an
        error message has been shown on the page.
    """
    try:
        return _load_model_from_disk()
    except FileNotFoundError:
        # The pickle is written by save_model.py, so point the user there.
        st.error("❌ 未找到模型文件!请先运行 save_model.py 训练模型")
        return None
|
||||||
|
|
||||||
|
# Page header
st.set_page_config(page_title="电信客户流失预测系统", page_icon="📱", layout="wide")
st.title("📱 电信客户流失预测系统")
st.divider()

# Bilingual labels shown in the form, mapped to the raw category values that
# are passed to the model. Declaring each mapping once keeps the select boxes
# and the label-to-value conversion from drifting apart.
contract_map = {
    "Month-to-month(月付)": "Month-to-month",
    "One year(1年)": "One year",
    "Two year(2年)": "Two year",
}
internet_map = {
    "DSL(宽带)": "DSL",
    "Fiber optic(光纤)": "Fiber optic",
    "No(无)": "No",
}
payment_map = {
    "Electronic check(电子支票)": "Electronic check",
    "Mailed check(邮寄支票)": "Mailed check",
    "Bank transfer(银行转账)": "Bank transfer (automatic)",
    "Credit card(信用卡)": "Credit card (automatic)",
}

# 1. Customer information input
st.subheader("1. 输入客户信息")
with st.form("customer_form"):
    col1, col2 = st.columns(2)
    with col1:
        tenure = st.number_input("在网时长(月)", min_value=0, max_value=72, value=12)
        monthly_charges = st.number_input("月消费金额", min_value=18.0, max_value=120.0, value=69.99)
        total_charges = st.number_input("总消费金额", min_value=18.0, max_value=8684.0, value=839.88)
    with col2:
        contract = st.selectbox("合约类型", list(contract_map))
        internet_service = st.selectbox("网络服务", list(internet_map))
        payment_method = st.selectbox("支付方式", list(payment_map))
    submit_btn = st.form_submit_button("🚀 预测流失风险")

# 2. Prediction result
if submit_btn:
    model = load_trained_model()
    # load_trained_model() signals failure with None, so test for exactly that
    # instead of relying on the truth value of the model object.
    if model is not None:
        # Build a one-row frame. Fields that the form does not ask for are
        # filled with fixed placeholder values.
        customer_data = pd.DataFrame({
            'customerID': ['TEST001'],
            'tenure': [tenure],
            'MonthlyCharges': [monthly_charges],
            'TotalCharges': [total_charges],
            'Contract': [contract_map[contract]],
            'InternetService': [internet_map[internet_service]],
            'PaymentMethod': [payment_map[payment_method]],
            'OnlineSecurity': ['No'],
            'TechSupport': ['No'],
            'PaperlessBilling': ['Yes'],
            'gender': ['Female'],
            'SeniorCitizen': [0],
            'Partner': ['Yes'],
            'Dependents': ['No'],
            'PhoneService': ['Yes'],
            'MultipleLines': ['No'],
            'OnlineBackup': ['Yes'],
            'DeviceProtection': ['No'],
            'StreamingTV': ['Yes'],
            'StreamingMovies': ['No'],
        })
        # Derived feature combinations.
        # NOTE(review): the training script visible alongside this file does
        # not create these columns - confirm the saved model actually uses them.
        customer_data['tenure_monthly'] = customer_data['tenure'] * customer_data['MonthlyCharges']
        customer_data['tenure_ratio'] = customer_data['tenure'] / (customer_data['TotalCharges'] + 1)
        customer_data['monthly_total_ratio'] = customer_data['MonthlyCharges'] / (customer_data['TotalCharges'] + 1)

        # Predicted class and the probability of the churn class (column 1).
        churn_pred = model.predict(customer_data)[0]
        churn_prob = model.predict_proba(customer_data)[:, 1][0]

        # Show the result
        st.divider()
        st.subheader("2. 流失风险预测结果")
        if churn_pred == 1:
            st.error(f"⚠️ 流失风险高(概率:{churn_prob:.2%})")
        else:
            st.success(f"✅ 流失风险低(概率:{churn_prob:.2%})")

# 3. Analysis and recommendations
# NOTE(review): the original nesting of this section could not be recovered;
# it is placed at top level like the other numbered sections - confirm.
st.divider()
st.subheader("3. 流失影响因素")
st.image("特征重要性TOP10.png", caption="影响流失的TOP10因素")
st.subheader("4. 留存建议")
st.markdown("- 短在网时长+高消费:提供长期合约折扣\n- 光纤用户:优化网络稳定性\n- 电子支票支付:自动续费优惠")
|
||||||
|
Can't render this file because it is too large.
|
6
main.py
Normal file
6
main.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
def main():
    """Print the project's greeting line to standard output."""
    greeting = "Hello from telco-churn-analysis!"
    print(greeting)


# Run the greeting only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||||
0
models/新建 文本文档.txt
Normal file
0
models/新建 文本文档.txt
Normal file
68
predict.py
Normal file
68
predict.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
# predict.py - 客户流失预测接口
|
||||||
|
import joblib
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
# Load the persisted model. The path must match the pickle file written by
# save_model.py.
try:
    model = joblib.load('telco_churn_model.pkl')
except FileNotFoundError:
    print("❌ 未找到模型文件,请先运行save_model.py生成telco_churn_model.pkl")
    # exit() is a helper injected by the `site` module and, called without an
    # argument, reports success. Raise SystemExit with a non-zero status so
    # callers and shells can see that loading failed.
    raise SystemExit(1)
else:
    print("✅ 模型加载成功!")
|
||||||
|
|
||||||
|
def predict_churn(customer_data):
    """Predict churn for the customers in ``customer_data``.

    Args:
        customer_data: DataFrame with one row per customer. It must contain
            ``customerID``, ``tenure``, ``MonthlyCharges`` and
            ``TotalCharges`` plus whatever other columns the loaded model
            expects.

    Returns:
        DataFrame with the customer ID, the predicted class (0 = no churn,
        1 = churn) and the churn probability rounded to three decimals.
    """
    # Work on a copy so the caller's frame does not silently gain columns.
    features = customer_data.copy()

    # Derived feature combinations.
    # NOTE(review): these are said to mirror the training-time features -
    # verify against the script that fits the model.
    features['tenure_monthly'] = features['tenure'] * features['MonthlyCharges']
    features['tenure_ratio'] = features['tenure'] / (features['TotalCharges'] + 1)
    features['monthly_total_ratio'] = features['MonthlyCharges'] / (features['TotalCharges'] + 1)

    # Predicted class and the probability of the churn class (column 1).
    churn_pred = model.predict(features)
    churn_prob = model.predict_proba(features)[:, 1]

    # Assemble the result table.
    result = pd.DataFrame({
        '客户ID': customer_data['customerID'],
        '是否流失(0=否,1=是)': churn_pred,
        '流失概率': churn_prob.round(3),
    })
    return result
|
||||||
|
|
||||||
|
# Example: score a single hand-written test customer.
if __name__ == "__main__":
    # One record with the same fields as the original dataset.
    sample_record = {
        'customerID': 'TEST001',            # customer ID
        'tenure': 12,                       # months with the company
        'MonthlyCharges': 69.99,            # monthly charge
        'TotalCharges': 839.88,             # total charge
        'Contract': 'Month-to-month',       # contract type
        'InternetService': 'Fiber optic',   # internet service
        'PaymentMethod': 'Electronic check',  # payment method
        'OnlineSecurity': 'No',             # no online security add-on
        'TechSupport': 'No',                # no tech support add-on
        'PaperlessBilling': 'Yes',          # paperless billing
        # Remaining fields of the original dataset.
        'gender': 'Female',
        'SeniorCitizen': 0,
        'Partner': 'Yes',
        'Dependents': 'No',
        'PhoneService': 'Yes',
        'MultipleLines': 'No',
        'OnlineBackup': 'Yes',
        'DeviceProtection': 'No',
        'StreamingTV': 'Yes',
        'StreamingMovies': 'No',
    }
    test_customer = pd.DataFrame([sample_record])

    # Run the prediction.
    result = predict_churn(test_customer)

    # Print the outcome.
    print("\n📊 客户流失预测结果:")
    print(result)
|
||||||
15
pyproject.toml
Normal file
15
pyproject.toml
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
[project]
|
||||||
|
name = "telco-churn-analysis"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Telco customer churn prediction with a Streamlit demo"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.13"
|
||||||
|
dependencies = [
|
||||||
|
"joblib>=1.5.3",
|
||||||
|
"matplotlib>=3.10.8",
|
||||||
|
"pandas>=2.3.3",
|
||||||
|
"polars>=1.37.0",
|
||||||
|
"scikit-learn>=1.8.0",
|
||||||
|
"seaborn>=0.13.2",
|
||||||
|
"streamlit>=1.52.2",
|
||||||
|
]
|
||||||
8
run.bat
Normal file
8
run.bat
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
@echo off
rem Launcher for the Telco churn Streamlit demo: sync dependencies, then start the app.

echo 🚀 正在同步项目依赖...
uv sync
rem Stop when the sync fails instead of launching against a missing or stale environment.
if errorlevel 1 (
    pause
    exit /b 1
)

echo 🚀 启动电信客户流失预测系统...
rem "uv run" executes Streamlit from the project environment that "uv sync" manages,
rem so the launcher also works when streamlit is not installed on the global PATH.
uv run streamlit run app.py

pause
|
||||||
31
save_model.py
Normal file
31
save_model.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
# save_model.py
|
||||||
|
import joblib
|
||||||
|
from src.data import load_data, split_data
|
||||||
|
from src.model import build_preprocessor
|
||||||
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
|
from sklearn.pipeline import Pipeline
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Load the dataset and split it into train and test parts.
    df = load_data()
    X_train, X_test, y_train, y_test = split_data(df)
    preprocessor = build_preprocessor(X_train)

    # Random forest with the tuned hyper-parameters.
    forest_params = {
        'n_estimators': 200,
        'max_depth': 15,
        'min_samples_split': 8,
        'min_samples_leaf': 4,
        'class_weight': 'balanced_subsample',
        'random_state': 42,
        'n_jobs': -1,
    }
    forest = RandomForestClassifier(**forest_params)

    # Chain preprocessing and the classifier, then fit on the training split.
    model = Pipeline(steps=[
        ('preprocessor', preprocessor),
        ('classifier', forest),
    ])
    model.fit(X_train, y_train)

    # Persist the fitted pipeline.
    joblib.dump(model, 'telco_churn_model.pkl')
    print("✅ 模型已保存为:telco_churn_model.pkl")
|
||||||
94
src/data.py
94
src/data.py
@ -1,54 +1,50 @@
|
|||||||
"""电信客户流失数据集加载与清洗(最终可运行版)"""
|
# 改用Pandas加载数据,彻底解决Polars类型冲突
|
||||||
import polars as pl
|
import pandas as pd
|
||||||
import os # 用于检查文件是否存在
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
import os

# Default dataset location, resolved relative to the project root (the parent
# of this file's directory) so the code does not depend on one machine's
# absolute path or on the current working directory.
file_path = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "data",
    "telco_churn.csv",
)


def load_data(path=None):
    """Load the Telco churn CSV with pandas and clean it.

    Args:
        path: CSV file to read. Defaults to the module-level ``file_path``.

    Returns:
        DataFrame in which ``TotalCharges`` is numeric (blank entries become
        0) and every remaining missing value is filled with 0.
    """
    csv_path = file_path if path is None else path

    # Read TotalCharges as text first: the column mixes numbers with blanks.
    df = pd.read_csv(csv_path, dtype={"TotalCharges": str})

    # Drop spaces, turn the resulting empty strings into "0", then convert.
    # Anything that still fails to parse is coerced to NaN and filled with 0.
    df["TotalCharges"] = df["TotalCharges"].str.replace(" ", "").replace("", "0")
    df["TotalCharges"] = pd.to_numeric(df["TotalCharges"], errors="coerce").fillna(0)

    # Basic fill for any other missing values.
    df = df.fillna(0)

    print("✅ 数据集加载并清洗完成!")
    print(f"📊 数据规模:{df.shape[0]}行 × {df.shape[1]}列")
    return df


def split_data(df):
    """Split the cleaned data into stratified train and test sets.

    Args:
        df: DataFrame containing ``customerID`` and ``Churn`` columns, the
            latter holding "Yes"/"No".

    Returns:
        ``(X_train, X_test, y_train, y_test)`` with a 20% test share. The
        label is mapped to 1 for "Yes" and 0 for "No".
    """
    # Separate features from the label.
    X = df.drop(["customerID", "Churn"], axis=1)
    y = df["Churn"].map({"Yes": 1, "No": 0})

    # Stratify on the label so both splits keep the same churn ratio.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    print(f"\n📊 数据集切分完成:")
    print(f" - 训练集:{len(X_train)}条,流失率:{y_train.mean():.2%}")
    print(f" - 测试集:{len(X_test)}条,流失率:{y_test.mean():.2%}")
    return X_train, X_test, y_train, y_test


def load_telco_data():
    """Legacy entry point kept so existing callers still resolve.

    Delegates to ``load_data()`` and therefore returns a pandas DataFrame.
    As before, a failure is printed and reported as ``None`` rather than
    raised.
    """
    try:
        return load_data()
    except Exception as e:
        print(f"\n❌ 数据处理出错:{type(e).__name__} → {e}")
        return None


# Manual smoke test
if __name__ == "__main__":
    try:
        df = load_data()
        X_train, X_test, y_train, y_test = split_data(df)
    except Exception as e:
        print(f"❌ 运行出错:{e}")
        print(f"当前路径:{file_path}")
        print("请检查:1. data文件夹存在 2. 数据集文件命名为telco_churn.csv")
|
||||||
110
src/model.py
Normal file
110
src/model.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
# model.py - 客户流失预测模型(随机森林优化版)
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.preprocessing import OneHotEncoder, StandardScaler
|
||||||
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
|
from sklearn.metrics import (
|
||||||
|
accuracy_score, classification_report,
|
||||||
|
roc_auc_score, confusion_matrix
|
||||||
|
)
|
||||||
|
from sklearn.compose import ColumnTransformer
|
||||||
|
from sklearn.pipeline import Pipeline
|
||||||
|
# 导入你已稳定的data.py中的函数
|
||||||
|
from src.data import load_data, split_data
|
||||||
|
|
||||||
|
def build_preprocessor(X):
    """Build the feature preprocessing step for the churn model.

    Numeric columns are standardised and text columns are one-hot encoded.
    Columns of any other dtype are not listed in either transformer.

    Args:
        X: Feature DataFrame whose dtypes decide which columns are treated
            as numeric and which as categorical.

    Returns:
        An unfitted ``ColumnTransformer`` with a ``'num'`` and a ``'cat'``
        transformer.
    """
    # Split the column names by dtype.
    numeric_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()
    # Accept pandas' dedicated string dtype as well as plain object columns.
    # The project allows pandas versions that may infer a string dtype for
    # text; selecting on 'object' alone would then drop every categorical
    # feature without any error.
    categorical_cols = X.select_dtypes(include=['object', 'string']).columns.tolist()

    # Numeric features: standardisation.
    numeric_transformer = Pipeline(steps=[
        ('scaler', StandardScaler())
    ])

    # Categorical features: one-hot encoding with the first level dropped.
    categorical_transformer = Pipeline(steps=[
        ('onehot', OneHotEncoder(sparse_output=False, drop='first', handle_unknown='ignore'))
    ])

    # Combine both transformers.
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numeric_transformer, numeric_cols),
            ('cat', categorical_transformer, categorical_cols)
        ])

    return preprocessor
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # ========== 1. Load and split the data ==========
    print("📌 开始加载数据集...")
    df = load_data()
    X_train, X_test, y_train, y_test = split_data(df)
    print("✅ 数据加载&切分完成")

    # ========== 2. Build the preprocessing + model pipeline ==========
    print("\n📌 构建预处理管道和随机森林模型...")
    preprocessor = build_preprocessor(X_train)

    # Random forest configured for the imbalanced churn label.
    model = Pipeline(steps=[
        ('preprocessor', preprocessor),
        ('classifier', RandomForestClassifier(
            n_estimators=200,
            max_depth=15,
            min_samples_split=8,
            min_samples_leaf=4,
            class_weight='balanced_subsample',
            random_state=42,
            n_jobs=-1
        ))
    ])

    # ========== 3. Train ==========
    print("📌 开始训练随机森林模型...")
    model.fit(X_train, y_train)
    print("✅ 模型训练完成")

    # ========== 4. Evaluate ==========
    print("\n📊 模型评估结果:")
    # Predicted class and probability of the churn class (column 1).
    y_pred = model.predict(X_test)
    y_pred_proba = model.predict_proba(X_test)[:, 1]

    # Headline metrics.
    print(f"🔹 准确率(Accuracy):{accuracy_score(y_test, y_pred):.4f}")
    print(f"🔹 AUC值:{roc_auc_score(y_test, y_pred_proba):.4f}")
    print("\n🔹 分类详细报告:")
    print(classification_report(y_test, y_pred, target_names=['未流失', '流失']))

    # Confusion matrix as a labelled table.
    print("\n🔹 混淆矩阵:")
    cm = confusion_matrix(y_test, y_pred)
    cm_df = pd.DataFrame(
        cm,
        index=['实际未流失', '实际流失'],
        columns=['预测未流失', '预测流失']
    )
    print(cm_df)

    # ========== 5. Feature importance ==========
    print("\n📈 特征重要性TOP10:")
    # Reuse the preprocessor that model.fit() already fitted as part of the
    # pipeline; fitting it a second time on the same data is redundant work.
    fitted_preprocessor = model.named_steps['preprocessor']
    feature_names = []
    # Numeric feature names keep their original column names.
    feature_names.extend(fitted_preprocessor.transformers_[0][2])
    # Categorical feature names come from the fitted one-hot encoder.
    cat_features = fitted_preprocessor.named_transformers_['cat'].named_steps['onehot']
    feature_names.extend(cat_features.get_feature_names_out(fitted_preprocessor.transformers_[1][2]))

    # Importance scores from the trained forest.
    importances = model.named_steps['classifier'].feature_importances_
    # Sort and print the ten most important features.
    importance_df = pd.DataFrame({
        '特征名': feature_names,
        '重要性': importances
    }).sort_values('重要性', ascending=False).head(10)
    print(importance_df)
|
||||||
BIN
telco_churn_model.pkl
Normal file
BIN
telco_churn_model.pkl
Normal file
Binary file not shown.
165
visualization.py
Normal file
165
visualization.py
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
# visualization.py - 客户流失预测模型可视化(直接运行即可)
|
||||||
|
import joblib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
import pandas as pd
|
||||||
|
from sklearn.metrics import confusion_matrix, roc_curve, auc
|
||||||
|
|
||||||
|
# -------------------------- Global plot settings (style, Chinese font) --------------------------
# Apply the style sheet first. A style sheet may define its own
# font.sans-serif list, and it would overwrite a font chosen before this
# call, leaving Chinese labels as empty boxes.
plt.style.use('seaborn-v0_8-whitegrid')
# Matplotlib uses the first installed family from this list, so the same
# file works on Windows (SimHei / Microsoft YaHei) and macOS (Arial Unicode MS).
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS']
# Render the minus sign with the ASCII hyphen to avoid a missing glyph.
plt.rcParams['axes.unicode_minus'] = False
|
||||||
|
|
||||||
|
# -------------------------- 1. Load the model and the data --------------------------
def load_model_and_data():
    """Load the trained model and the held-out test split.

    Returns:
        ``(model, X_test, y_test)``.

    Raises:
        FileNotFoundError: ``telco_churn_model.pkl`` does not exist.
        ImportError: ``src.data`` cannot be imported.
    """
    # Load the persisted model.
    try:
        model = joblib.load('telco_churn_model.pkl')
    except FileNotFoundError as err:
        # The pickle is written by save_model.py, so point the user there.
        raise FileNotFoundError("❌ 未找到模型文件!请先运行 save_model.py 训练模型") from err
    print("✅ 模型加载成功")

    # Reuse the loading/splitting logic from src/data.py. Only the import is
    # guarded, so an error raised while reading the data is not mislabelled
    # as a missing module.
    try:
        from src.data import load_data, split_data
    except ImportError as err:
        raise ImportError("❌ 未找到 src/data.py!请确保项目目录结构正确") from err

    df = load_data()
    X_train, X_test, y_train, y_test = split_data(df)
    # Report the actual size instead of a number that only holds for one dataset.
    print(f"✅ 测试集数据加载成功(共{len(X_test)}条)")
    return model, X_test, y_test
|
||||||
|
|
||||||
|
# -------------------------- 2. Feature importance TOP 10 --------------------------
def plot_feature_importance(model):
    """Plot the ten most important features and save the chart as a PNG.

    Args:
        model: Fitted pipeline with a ``'preprocessor'`` step that provides
            ``get_feature_names_out()`` and a ``'classifier'`` step that
            exposes ``feature_importances_``.
    """
    # Feature names after preprocessing and their importance scores.
    preprocessor = model.named_steps['preprocessor']
    feature_names = preprocessor.get_feature_names_out()
    feature_importance = model.named_steps['classifier'].feature_importances_

    # Sort by importance and keep the top ten.
    feature_df = pd.DataFrame({
        '特征名': feature_names,
        '重要性': feature_importance
    }).sort_values('重要性', ascending=False).head(10)

    # Readable labels for the features that usually rank highest.
    feature_name_map = {
        'tenure': '客户在网时长',
        'TotalCharges': '总消费金额',
        'MonthlyCharges': '月消费金额',
        'Contract_Two year': '合约期-2年',
        'InternetService_Fiber optic': '网络类型-光纤',
        'PaymentMethod_Electronic check': '支付方式-电子支票',
        'Contract_One year': '合约期-1年',
        'OnlineSecurity_Yes': '在线安全服务-有',
        'TechSupport_Yes': '技术支持-有',
        'PaperlessBilling_Yes': '电子账单-是'
    }

    def _display_name(raw_name):
        """Map a pipeline feature name to its chart label."""
        # A ColumnTransformer may prefix names with "<transformer>__" (for
        # example "num__tenure"). Strip that prefix, otherwise the lookup
        # below never matches and every label is just a truncated raw name.
        base_name = raw_name.split('__', 1)[-1]
        return feature_name_map.get(base_name, base_name[:15])

    feature_df['简化特征名'] = feature_df['特征名'].map(_display_name)

    # Horizontal bar chart. The palette is tied to an explicit hue because
    # seaborn deprecates passing `palette` without `hue`; the legend would
    # only repeat the axis labels, so it is turned off.
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(
        x='重要性', y='简化特征名', hue='简化特征名', data=feature_df,
        palette='viridis_r', legend=False, ax=ax
    )

    # Title and axis labels.
    ax.set_title('客户流失预测 - 特征重要性TOP10', fontsize=16, fontweight='bold', pad=20)
    ax.set_xlabel('重要性得分', fontsize=12)
    ax.set_ylabel('特征', fontsize=12)
    ax.tick_params(axis='y', labelsize=10)

    # Print each score just to the right of its bar.
    for i, v in enumerate(feature_df['重要性']):
        ax.text(v + 0.002, i, f'{v:.3f}', va='center', fontsize=9)

    # Save the chart.
    plt.tight_layout()
    plt.savefig('特征重要性TOP10.png', dpi=300, bbox_inches='tight')
    print("✅ 特征重要性图表已保存为:特征重要性TOP10.png")
|
||||||
|
|
||||||
|
# -------------------------- 3. Confusion matrix --------------------------
def plot_confusion_matrix(model, X_test, y_test):
    """Plot the confusion matrix as a heat map and save it as a PNG.

    Args:
        model: Fitted estimator with a ``predict`` method.
        X_test: Test features.
        y_test: True labels, encoded as 0 (no churn) and 1 (churn).
    """
    # Predicted labels.
    y_pred = model.predict(X_test)

    # Pin the label order so the matrix is always 2x2, even when one class
    # is absent from the test labels or the predictions.
    cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
    # Tick labels (0 = no churn, 1 = churn).
    labels = ['未流失', '流失']

    # Heat map; fmt='d' prints the counts as integers.
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm, annot=True, fmt='d', cmap='Blues',
        xticklabels=labels, yticklabels=labels, ax=ax,
        cbar_kws={'label': '客户数量'}
    )

    # Title and axis labels.
    ax.set_title('客户流失预测 - 混淆矩阵', fontsize=16, fontweight='bold', pad=20)
    ax.set_xlabel('预测标签', fontsize=12)
    ax.set_ylabel('真实标签', fontsize=12)

    # Summary line: accuracy and recall on the churn class. Guard both
    # divisions so an empty class yields 0 instead of NaN in the caption.
    total = cm.sum()
    accuracy = (cm[0, 0] + cm[1, 1]) / total if total else 0.0
    actual_churn = cm[1, 0] + cm[1, 1]
    recall_churn = cm[1, 1] / actual_churn if actual_churn else 0.0
    ax.text(0.5, -0.15, f'准确率:{accuracy:.3f} | 流失识别率:{recall_churn:.3f}',
            ha='center', transform=ax.transAxes, fontsize=11)

    # Save the chart.
    plt.tight_layout()
    plt.savefig('混淆矩阵.png', dpi=300, bbox_inches='tight')
    print("✅ 混淆矩阵图表已保存为:混淆矩阵.png")
|
||||||
|
|
||||||
|
# -------------------------- 4. Optional: ROC curve --------------------------
def plot_roc_curve(model, X_test, y_test):
    """Draw the ROC curve with its AUC and write it to ``ROC曲线.png``."""
    # Probability of the churn class (column 1) for every test row.
    churn_scores = model.predict_proba(X_test)[:, 1]
    fpr, tpr, _ = roc_curve(y_test, churn_scores)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(8, 6))

    # Model curve plus the diagonal reference line.
    curve_label = f'ROC曲线 (AUC = {roc_auc:.3f})'
    ax.plot(fpr, tpr, color='darkorange', lw=2, label=curve_label)
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='随机猜测')

    # Axis limits, title, labels, legend and a light grid.
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_title('客户流失预测 - ROC曲线', fontsize=16, fontweight='bold', pad=20)
    ax.set_xlabel('假阳性率(误判为流失)', fontsize=12)
    ax.set_ylabel('真阳性率(正确识别流失)', fontsize=12)
    ax.legend(loc="lower right", fontsize=11)
    ax.grid(True, alpha=0.3)

    # Write the figure to disk.
    plt.tight_layout()
    plt.savefig('ROC曲线.png', dpi=300, bbox_inches='tight')
    print("✅ ROC曲线图表已保存为:ROC曲线.png")
|
||||||
|
|
||||||
|
# -------------------------- Entry point: generate every chart --------------------------
if __name__ == "__main__":
    print("🚀 开始生成可视化图表...")
    try:
        # Load the model together with the test split.
        model, X_test, y_test = load_model_and_data()

        # Three charts: feature importance, confusion matrix, ROC curve.
        plot_feature_importance(model)
        plot_confusion_matrix(model, X_test, y_test)
        plot_roc_curve(model, X_test, y_test)

        # Summary of the files that were written.
        summary_lines = (
            "\n🎉 所有图表生成完成!文件保存在项目根目录:",
            "1. 特征重要性TOP10.png(业务洞察核心)",
            "2. 混淆矩阵.png(模型效果直观展示)",
            "3. ROC曲线.png(进阶评估,AUC值)",
        )
        for line in summary_lines:
            print(line)
    except Exception as e:
        print(f"\n❌ 生成失败:{e!s}")
|
||||||
BIN
特征重要性TOP10.png
Normal file
BIN
特征重要性TOP10.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 120 KiB |
Loading…
Reference in New Issue
Block a user