telco-customer-churn-predic.../predict.py

68 lines
2.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# predict.py - 客户流失预测接口
import joblib
import pandas as pd
# Load the persisted model. The path must match the pkl file written by
# save_model.py.
# NOTE(review): joblib.load unpickles the file, and unpickling can execute
# arbitrary code -- only point this at a model file from a trusted source.
try:
    model = joblib.load('telco_churn_model.pkl')
except FileNotFoundError:
    print("❌ 未找到模型文件请先运行save_model.py生成telco_churn_model.pkl")
    # Exit with a non-zero status so shells and schedulers can see the
    # failure. Raising SystemExit directly also avoids the exit() helper,
    # which is injected by the site module for interactive sessions and is
    # not guaranteed to exist (e.g. under `python -S`).
    raise SystemExit(1)
else:
    # Only reached when the load above succeeded.
    print("✅ 模型加载成功!")
def predict_churn(customer_data):
    """Predict churn labels and churn probabilities for a batch of customers.

    The engineered features are added to a new DataFrame, so the frame passed
    in by the caller is left unmodified.

    Args:
        customer_data: pandas DataFrame with one row per customer. This
            function reads the ``customerID``, ``tenure``, ``MonthlyCharges``
            and ``TotalCharges`` columns directly; all columns (plus the
            three engineered ones) are handed to the module-level ``model``.

    Returns:
        pandas DataFrame with three columns: the customer ID, the predicted
        label (0 = not churned, 1 = churned) and the value from column 1 of
        ``model.predict_proba`` rounded to three decimals.

    Raises:
        KeyError: if one of the columns read directly is missing.
    """
    tenure = customer_data['tenure']
    monthly_charges = customer_data['MonthlyCharges']
    # The +1 presumably guards against a zero denominator for customers with
    # no accumulated charges -- keep it identical to the training pipeline.
    total_plus_one = customer_data['TotalCharges'] + 1

    # Feature combinations must match the ones used at training time,
    # otherwise predictions will be inaccurate. assign() returns a new
    # frame instead of writing the columns into the caller's object.
    features = customer_data.assign(
        tenure_monthly=tenure * monthly_charges,
        tenure_ratio=tenure / total_plus_one,
        monthly_total_ratio=monthly_charges / total_plus_one,
    )

    # Run the prediction.
    churn_pred = model.predict(features)  # 0 = not churned, 1 = churned
    churn_prob = model.predict_proba(features)[:, 1]  # churn probability

    # Assemble the result table. The column labels are runtime strings that
    # callers may select on, so they are kept exactly as they were.
    return pd.DataFrame({
        '客户ID': features['customerID'],
        '是否流失0=否1=是)': churn_pred,
        '流失概率': churn_prob.round(3),  # keep 3 decimal places
    })
# Example: predict the churn probability for a single test customer.
if __name__ == "__main__":
    # One synthetic customer described as a flat record; the field names are
    # meant to mirror the original dataset.
    sample_profile = {
        'customerID': 'TEST001',                # customer ID
        'tenure': 12,                           # months with the provider
        'MonthlyCharges': 69.99,                # monthly spend
        'TotalCharges': 839.88,                 # total spend
        'Contract': 'Month-to-month',           # contract type (monthly)
        'InternetService': 'Fiber optic',       # fiber internet
        'PaymentMethod': 'Electronic check',    # pays by electronic check
        'OnlineSecurity': 'No',                 # no online security service
        'TechSupport': 'No',                    # no tech support
        'PaperlessBilling': 'Yes',              # paperless billing
        # Remaining required features, filled in as in the original dataset.
        'gender': 'Female',
        'SeniorCitizen': 0,
        'Partner': 'Yes',
        'Dependents': 'No',
        'PhoneService': 'Yes',
        'MultipleLines': 'No',
        'OnlineBackup': 'Yes',
        'DeviceProtection': 'No',
        'StreamingTV': 'Yes',
        'StreamingMovies': 'No',
    }

    # Wrap each value in a one-element list so the frame has exactly one row.
    sample_frame = pd.DataFrame(
        {column: [value] for column, value in sample_profile.items()}
    )

    # Run the prediction and show the outcome.
    prediction = predict_churn(sample_frame)
    print("\n📊 客户流失预测结果:")
    print(prediction)