generated from Python-2026Spring/assignment-05-final-project-template
68 lines
2.6 KiB
Python
68 lines
2.6 KiB
Python
# predict.py - 客户流失预测接口
|
||
import sys

import joblib
import pandas as pd
|
||
|
||
# Load the persisted model once, at import time. The path is relative to the
# current working directory, so it must match where save_model.py wrote the
# .pkl file (run this script from that same directory).
MODEL_PATH = 'telco_churn_model.pkl'

try:
    model = joblib.load(MODEL_PATH)
except FileNotFoundError:
    print("❌ 未找到模型文件,请先运行save_model.py生成telco_churn_model.pkl", file=sys.stderr)
    # Exit with a non-zero status so shells/CI see the failure. The bare
    # exit() builtin is an interactive-shell helper and, called without an
    # argument, reports success (status 0).
    sys.exit(1)
else:
    print("✅ 模型加载成功!")
|
||
|
||
def predict_churn(customer_data, estimator=None):
    """Predict churn labels and churn probabilities for a batch of customers.

    Args:
        customer_data: DataFrame with one row per customer. It must contain a
            ``customerID`` column and numeric ``tenure``, ``MonthlyCharges``
            and ``TotalCharges`` columns (used to derive the engineered
            features below), plus whatever other raw feature columns the
            saved model expects (presumably the original Telco dataset
            columns -- confirm against save_model.py). The caller's DataFrame
            is not modified.
        estimator: Optional fitted classifier exposing ``predict`` and
            ``predict_proba``. When omitted, the module-level ``model``
            loaded from disk is used.

    Returns:
        DataFrame indexed like ``customer_data`` with columns:
        ``客户ID`` (customer ID), ``是否流失(0=否,1=是)`` (predicted label,
        0 = not churned, 1 = churned) and ``流失概率`` (churn probability,
        rounded to 3 decimals).
    """
    # Only touch the module-level model when no estimator was supplied.
    clf = model if estimator is None else estimator

    # Work on a copy so the engineered columns do not leak back into the
    # caller's DataFrame (e.g. when the same frame is scored twice).
    features = customer_data.copy()

    # Engineered feature combinations -- must stay identical to the ones used
    # at training time, otherwise predictions will be wrong. The "+ 1" keeps
    # the ratios finite when TotalCharges is 0.
    features['tenure_monthly'] = features['tenure'] * features['MonthlyCharges']
    features['tenure_ratio'] = features['tenure'] / (features['TotalCharges'] + 1)
    features['monthly_total_ratio'] = features['MonthlyCharges'] / (features['TotalCharges'] + 1)

    # NOTE(review): customerID is passed to the model together with the
    # features, exactly as before; presumably the saved pipeline ignores or
    # drops that column -- verify against the training code.
    churn_pred = clf.predict(features)  # 0 = not churned, 1 = churned
    # Second predict_proba column; this is the churn class assuming the
    # classifier's classes_ are [0, 1].
    churn_prob = clf.predict_proba(features)[:, 1]

    # Assemble the result; the index follows the customerID Series.
    result = pd.DataFrame({
        '客户ID': features['customerID'],
        '是否流失(0=否,1=是)': churn_pred,
        '流失概率': churn_prob.round(3),  # keep 3 decimal places
    })

    return result
|
||
|
||
# Demo: predict the churn probability of one hand-built test customer.
if __name__ == "__main__":
    # (column, value) pairs for a single customer; the columns mirror the
    # feature columns of the original dataset.
    demo_fields = [
        ('customerID', 'TEST001'),              # customer ID
        ('tenure', 12),                         # tenure in months
        ('MonthlyCharges', 69.99),              # monthly charge
        ('TotalCharges', 839.88),               # total charged so far
        ('Contract', 'Month-to-month'),         # contract type (monthly)
        ('InternetService', 'Fiber optic'),     # fiber-optic internet
        ('PaymentMethod', 'Electronic check'),  # pays by electronic check
        ('OnlineSecurity', 'No'),               # no online-security service
        ('TechSupport', 'No'),                  # no tech support
        ('PaperlessBilling', 'Yes'),            # paperless billing
        # Remaining required features (filled in following the original dataset)
        ('gender', 'Female'),
        ('SeniorCitizen', 0),
        ('Partner', 'Yes'),
        ('Dependents', 'No'),
        ('PhoneService', 'Yes'),
        ('MultipleLines', 'No'),
        ('OnlineBackup', 'Yes'),
        ('DeviceProtection', 'No'),
        ('StreamingTV', 'Yes'),
        ('StreamingMovies', 'No'),
    ]
    # One-row frame: every column holds a single-element list.
    demo_customer = pd.DataFrame({column: [value] for column, value in demo_fields})

    # Score the customer and show the outcome.
    demo_result = predict_churn(demo_customer)

    print("\n📊 客户流失预测结果:")
    print(demo_result)