akaAKR47/agent.py
akr f47c7d7196 feat: 实现客户流失预测与行动建议闭环系统
添加完整的客户流失预测系统,包括数据处理、模型训练、预测和行动建议功能。主要包含以下模块:
1. 数据预处理流水线(Polars + Pandera)
2. 机器学习模型训练(LightGBM + Logistic Regression)
3. AI Agent预测和建议工具
4. Streamlit交互式Web界面
5. 完整的课程设计报告文档
2026-01-15 15:19:07 +08:00

200 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from pydantic import BaseModel, Field
from typing import Dict, Any, List
import polars as pl
import joblib
from data_processing import preprocess_data
# 定义输入输出模型
class CustomerData(BaseModel):
    """Raw attributes of one telecom customer, used as model input.

    All fields are required. The allowed values listed in each description
    are informational only: the fields are plain ``str``/``int``/``float``
    and nothing here restricts them to those values. Field order is kept
    stable because ``model_dump()`` output is turned into a DataFrame
    elsewhere in this file.
    """

    # --- Demographics ---
    gender: str = Field(description="性别: Male 或 Female")
    SeniorCitizen: int = Field(description="是否为老年人: 0 或 1")
    Partner: str = Field(description="是否有伴侣: Yes 或 No")
    Dependents: str = Field(description="是否有家属: Yes 或 No")

    # --- Account tenure (months, per the description) ---
    tenure: int = Field(description="在网时长,单位为月")

    # --- Phone services ---
    PhoneService: str = Field(description="是否开通电话服务: Yes 或 No")
    MultipleLines: str = Field(
        description="是否开通多条线路: Yes、No 或 No phone service"
    )

    # --- Internet service and its add-ons ---
    InternetService: str = Field(
        description="网络服务类型: DSL、Fiber optic 或 No"
    )
    OnlineSecurity: str = Field(
        description="是否开通在线安全服务: Yes、No 或 No internet service"
    )
    OnlineBackup: str = Field(
        description="是否开通在线备份服务: Yes、No 或 No internet service"
    )
    DeviceProtection: str = Field(
        description="是否开通设备保护服务: Yes、No 或 No internet service"
    )
    TechSupport: str = Field(
        description="是否开通技术支持服务: Yes、No 或 No internet service"
    )
    StreamingTV: str = Field(
        description="是否开通流媒体电视服务: Yes、No 或 No internet service"
    )
    StreamingMovies: str = Field(
        description="是否开通流媒体电影服务: Yes、No 或 No internet service"
    )

    # --- Contract and billing ---
    Contract: str = Field(
        description="合同类型: Month-to-month、One year 或 Two year"
    )
    PaperlessBilling: str = Field(description="是否使用无纸化账单: Yes 或 No")
    PaymentMethod: str = Field(description="支付方式")
    MonthlyCharges: float = Field(description="月费用")
    TotalCharges: float = Field(description="总费用")
class ChurnPrediction(BaseModel):
    """Outcome of a churn prediction for a single customer.

    Attributes:
        prediction: Class label, 1 for churn and 0 for no churn.
        probability: Churn probability the label was derived from.
        model_used: Name of the model that produced the prediction.
    """

    # ``model_used`` starts with ``model_``, which pydantic v2 treats as a
    # protected namespace and warns about on some releases. The name is part
    # of the public interface, so the namespace check is disabled rather
    # than renaming the field.
    model_config = {"protected_namespaces": ()}

    # The two cases in this description were previously fused together
    # ("不流失1"); a separator is restored so the schema text is readable.
    prediction: int = Field(..., description="预测结果: 0 表示不流失, 1 表示流失")
    probability: float = Field(..., description="流失概率")
    model_used: str = Field(..., description="使用的模型")
class ActionSuggestion(BaseModel):
    """Retention advice for one customer, derived from a churn prediction.

    Attributes:
        customer_id: Identifier of the customer the advice is for.
        prediction: Class label, 1 for churn and 0 for no churn.
        probability: Churn probability associated with the label.
        suggestions: Human-readable action items, in display order.
    """

    customer_id: str = Field(..., description="客户ID")
    # The two cases in this description were previously fused together
    # ("不流失1"); a separator is restored so the schema text is readable.
    prediction: int = Field(..., description="预测结果: 0 表示不流失, 1 表示流失")
    probability: float = Field(..., description="流失概率")
    suggestions: List[str] = Field(..., description="可执行的行动建议")
# Agent工具类
class ChurnPredictionAgent:
    """Agent tools for churn prediction and retention advice.

    Provides three tools: single-customer prediction, rule-based action
    suggestions, and batch prediction.
    """

    # Defaults preserve the previously hard-coded behavior.
    DEFAULT_MODEL_PATH = "models/lightgbm_model.pkl"
    DEFAULT_MODEL_NAME = "lightgbm"
    DEFAULT_THRESHOLD = 0.5

    # Cut-offs used by the suggestion rules.
    NEW_CUSTOMER_TENURE_MONTHS = 12
    LOYAL_CUSTOMER_TENURE_MONTHS = 24
    HIGH_MONTHLY_CHARGE = 70

    def __init__(
        self,
        model_path: str = DEFAULT_MODEL_PATH,
        model_name: str = DEFAULT_MODEL_NAME,
        threshold: float = DEFAULT_THRESHOLD,
    ):
        """Load the serialized model.

        Args:
            model_path: Path of the joblib-serialized model. Defaults to the
                LightGBM model, chosen because it usually performs better.
            model_name: Label reported in ``ChurnPrediction.model_used``.
            threshold: Probability at or above which a customer is
                classified as churning.
        """
        # NOTE: joblib.load is pickle-based and can execute arbitrary code;
        # only point model_path at trusted files.
        self.model = joblib.load(model_path)
        self.model_name = model_name
        self.threshold = threshold

    # Tool 1: ML prediction
    def predict_churn(self, customer_data: CustomerData) -> ChurnPrediction:
        """Predict whether a single customer will churn.

        Args:
            customer_data: Raw attributes of the customer.

        Returns:
            ChurnPrediction: Label, churn probability and model name.
        """
        # Kept as a function-scope import, matching the original code.
        from data_processing import preprocess_single_customer

        # One-row Polars DataFrame built from the model's field dict.
        frame = pl.DataFrame([customer_data.model_dump()])
        features = preprocess_single_customer(frame)

        # Row 0, column 1 of predict_proba is taken as the churn
        # probability (assumes class order [0, 1] - TODO confirm).
        # Cast to a builtin float so no numpy scalar leaks into the result.
        probability = float(self.model.predict_proba(features)[0, 1])

        return ChurnPrediction(
            prediction=int(probability >= self.threshold),
            probability=probability,
            model_used=self.model_name,
        )

    # Tool 2: action suggestions
    def get_action_suggestions(self, customer_id: str, prediction: int,
                               probability: float,
                               customer_data: CustomerData) -> ActionSuggestion:
        """Build actionable retention advice from a prediction.

        Args:
            customer_id: Identifier of the customer.
            prediction: Predicted label; 1 selects the high-risk rules, any
                other value selects the low-risk rules.
            probability: Churn probability, shown as a percentage in the
                headline suggestion.
            customer_data: Raw attributes the rules inspect.

        Returns:
            ActionSuggestion: The inputs echoed back plus the suggestions.
        """
        if prediction == 1:
            suggestions = self._high_risk_suggestions(
                customer_id, probability, customer_data
            )
        else:
            suggestions = self._low_risk_suggestions(
                customer_id, probability, customer_data
            )

        return ActionSuggestion(
            customer_id=customer_id,
            prediction=prediction,
            probability=probability,
            suggestions=suggestions,
        )

    def _high_risk_suggestions(self, customer_id: str, probability: float,
                               customer_data: CustomerData) -> List[str]:
        """Return the headline plus every matching high-risk rule message."""
        # The headline previously placed the customer ID and the percentage
        # directly next to each other, so they ran together in the output;
        # a separator is now inserted between them.
        headline = f"客户 {customer_id} 有 {probability:.2%} 的概率会流失,需要重点关注"

        # (condition, message) pairs, listed in display order.
        rules = (
            (
                customer_data.Contract == "Month-to-month",
                "建议提供长期合同折扣,鼓励客户转为一年或两年合同",
            ),
            (
                customer_data.TechSupport == "No",
                "建议提供免费的技术支持服务,提高客户满意度",
            ),
            (
                customer_data.OnlineSecurity == "No",
                "建议提供免费的在线安全服务,增加客户粘性",
            ),
            (
                customer_data.tenure < self.NEW_CUSTOMER_TENURE_MONTHS,
                "建议提供新客户忠诚度奖励计划,鼓励客户继续使用服务",
            ),
            (
                customer_data.MonthlyCharges > self.HIGH_MONTHLY_CHARGE,
                f"客户月费用较高 ({customer_data.MonthlyCharges} 元),建议提供费用优化方案",
            ),
        )
        return [headline] + [message for applies, message in rules if applies]

    def _low_risk_suggestions(self, customer_id: str, probability: float,
                              customer_data: CustomerData) -> List[str]:
        """Return the headline plus every matching low-risk rule message."""
        headline = f"客户 {customer_id} 流失风险较低 ({probability:.2%}),可维持现有服务"

        # (condition, message) pairs, listed in display order.
        rules = (
            (
                customer_data.Contract == "Month-to-month",
                "建议定期发送满意度调查,了解客户需求",
            ),
            (
                customer_data.tenure >= self.LOYAL_CUSTOMER_TENURE_MONTHS,
                "建议提供忠诚客户专属优惠,巩固客户关系",
            ),
        )
        return [headline] + [message for applies, message in rules if applies]

    # Tool 3: batch prediction (extra tool)
    def batch_predict(self, customer_data_list: List[CustomerData]) -> List[ChurnPrediction]:
        """Predict churn for several customers, one at a time.

        Args:
            customer_data_list: Customers to score.

        Returns:
            List[ChurnPrediction]: One result per input, in input order.
        """
        return [self.predict_churn(customer) for customer in customer_data_list]
# 测试Agent
if __name__ == "__main__":
    # Manual smoke test: score one sample customer and print the advice.
    churn_agent = ChurnPredictionAgent()

    # Sample customer, built from a plain mapping of field values.
    sample_fields = {
        "gender": "Male",
        "SeniorCitizen": 0,
        "Partner": "Yes",
        "Dependents": "No",
        "tenure": 12,
        "PhoneService": "Yes",
        "MultipleLines": "No",
        "InternetService": "Fiber optic",
        "OnlineSecurity": "No",
        "OnlineBackup": "Yes",
        "DeviceProtection": "No",
        "TechSupport": "No",
        "StreamingTV": "Yes",
        "StreamingMovies": "Yes",
        "Contract": "Month-to-month",
        "PaperlessBilling": "Yes",
        "PaymentMethod": "Electronic check",
        "MonthlyCharges": 79.85,
        "TotalCharges": 977.6,
    }
    sample_customer = CustomerData(**sample_fields)

    # Step 1: run the ML prediction tool.
    print("=== 使用ML预测工具 ===")
    outcome = churn_agent.predict_churn(sample_customer)
    if outcome.prediction == 1:
        outcome_label = "会流失"
    else:
        outcome_label = "不会流失"
    print(f"预测结果: {outcome_label}")
    print(f"流失概率: {outcome.probability:.2%}")
    print(f"使用模型: {outcome.model_used}")

    # Step 2: run the action-suggestion tool on that prediction.
    print("\n=== 使用行动建议工具 ===")
    advice = churn_agent.get_action_suggestions(
        customer_id="TEST-123",
        prediction=outcome.prediction,
        probability=outcome.probability,
        customer_data=sample_customer,
    )
    if advice.prediction == 1:
        advice_label = "会流失"
    else:
        advice_label = "不会流失"
    print(f"客户ID: {advice.customer_id}")
    print(f"预测结果: {advice_label}")
    print(f"流失概率: {advice.probability:.2%}")
    print("行动建议:")
    for position, item in enumerate(advice.suggestions, start=1):
        print(f" {position}. {item}")