G09-BankMarketing/train_model.py
2026-01-16 19:22:13 +08:00

91 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json

import joblib
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
# Step 1: load the raw dataset into a DataFrame.
# The path is relative, so the script must be launched from the directory
# that contains bank.csv.
print("正在加载数据...")
data_file = 'bank.csv'
df = pd.read_csv(data_file)
# 2. Data preprocessing
print("正在进行数据预处理...")

# Drop the `duration` column, if present, to avoid data leakage.
if 'duration' in df.columns:
    df = df.drop('duration', axis=1)

# Separate the features from the target column.
X = df.drop('deposit', axis=1)
y = df['deposit']

# Encode the target (yes -> 1, no -> 0). LabelEncoder assigns codes in sorted
# class order, which is what produces that mapping.
le_target = LabelEncoder()
y = le_target.fit_transform(y)

# Split the feature columns by dtype. Every column that is not numeric is
# treated as categorical, so the two lists always partition X.columns.
# Selecting only 'object' / int64 / float64 would silently skip dtypes such
# as int32, bool, 'category' or pandas' string dtype. Such columns would then
# be neither encoded nor recorded, and XGBoost would reject them.
numeric_cols = X.select_dtypes(include='number').columns.tolist()
categorical_cols = [col for col in X.columns if col not in numeric_cols]

# Record the column layout so the Agent can rebuild its inputs consistently.
feature_meta = {
    'numeric_cols': numeric_cols,
    'categorical_cols': categorical_cols,
    'all_cols': list(X.columns),
}

# Label-encode every categorical feature. XGBoost can handle categorical
# features, but they usually need to be converted to numeric values first.
# The fitted encoders are kept so the Agent's inference path can apply
# exactly the same mapping to new rows.
encoders = {}
for col in categorical_cols:
    le = LabelEncoder()
    X[col] = le.fit_transform(X[col])
    encoders[col] = le
# 3. Split into train / test sets (80 / 20).
# Stratifying on the target keeps the yes/no class ratio (approximately) the
# same in both splits, so the test metrics reflect the real class balance.
# random_state fixes the shuffle for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
# 4. Train the model
print("正在训练 XGBoost 模型...")
# `use_label_encoder` is intentionally not passed. It was deprecated in
# XGBoost 1.6 and removed in 2.0, where it is ignored with a warning at fit
# time. It is unnecessary here because the target passed to fit() is already
# integer-encoded earlier in this script.
model = xgb.XGBClassifier(
    n_estimators=100,
    learning_rate=0.1,
    max_depth=5,
    eval_metric='logloss',
)
model.fit(X_train, y_train)
# 5. Evaluate the model on the held-out test set
y_pred = model.predict(X_test)
# Predicted probability of the positive class (label 1, i.e. 'yes').
y_pred_proba = model.predict_proba(X_test)[:, 1]

print("\n模型评估结果:")
print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
# ROC-AUC is computed from the positive-class probabilities, so it measures
# ranking quality independently of the 0.5 decision threshold behind y_pred.
print(f"ROC-AUC: {roc_auc_score(y_test, y_pred_proba):.4f}")
print("\nClassification Report:")
print(classification_report(y_test, y_pred))
# 6. Persist the trained model together with the preprocessing objects
#    (feature encoders, target encoder, column metadata) in a single
#    joblib file.
print("\n正在保存模型和预处理工具...")
artifacts = dict(
    model=model,
    encoders=encoders,
    target_encoder=le_target,
    feature_meta=feature_meta,
)
joblib.dump(artifacts, 'model_artifacts.pkl')

# Rank the features by the model's feature_importances_ for reference.
# NOTE: this ranking is only printed to the console; it is not written to disk.
feat_imp_df = pd.DataFrame(
    {'Feature': X.columns, 'Importance': model.feature_importances_}
).sort_values(by='Importance', ascending=False)
print("\n特征重要性 Top 5:")
print(feat_imp_df.head(5))

print("\n完成!模型已保存为 'model_artifacts.pkl'")