117 lines
3.7 KiB
Python
117 lines
3.7 KiB
Python
|
|
import sys
|
||
|
|
import os
|
||
|
|
|
||
|
|
# 修复模块路径问题,让你可以在根目录直接 python src/train.py
|
||
|
|
sys.path.append(os.getcwd())
|
||
|
|
|
||
|
|
import joblib
|
||
|
|
import pandas as pd
|
||
|
|
import numpy as np
|
||
|
|
from sklearn.model_selection import train_test_split
|
||
|
|
from sklearn.pipeline import Pipeline
|
||
|
|
from sklearn.compose import ColumnTransformer
|
||
|
|
from sklearn.impute import SimpleImputer
|
||
|
|
from sklearn.preprocessing import StandardScaler, OneHotEncoder
|
||
|
|
from sklearn.linear_model import LogisticRegression
|
||
|
|
from sklearn.ensemble import RandomForestClassifier
|
||
|
|
from sklearn.metrics import classification_report, accuracy_score, f1_score
|
||
|
|
|
||
|
|
from src.data import generate_data, preprocess_data
|
||
|
|
|
||
|
|
# Output directory for trained artifacts, resolved relative to the current working directory.
MODELS_DIR: str = "models"
# File that receives the fitted pipeline (written with joblib.dump in train()).
MODEL_PATH: str = os.path.join(MODELS_DIR, "model.pkl")
|
||
|
|
|
||
|
|
def get_pipeline(
    model_type="rf",
    *,
    numeric_features=None,
    categorical_features=None,
    random_state=42,
):
    """Build the standard scikit-learn preprocessing + classifier pipeline.

    Steps:
      1. Numeric features -> mean imputation -> standardization.
      2. Categorical features -> most-frequent imputation -> one-hot encoding
         (categories not seen during fit are ignored at predict time).
      3. Classifier -> LogisticRegression ("lr") or RandomForestClassifier ("rf").

    Args:
        model_type: "lr" for logistic regression, "rf" for random forest.
        numeric_features: Numeric column names. Defaults to
            ["study_hours", "sleep_hours", "attendance_rate", "stress_level"].
        categorical_features: Categorical column names. Defaults to ["study_type"].
        random_state: Seed passed to the classifier for reproducible fits.

    Returns:
        An unfitted ``Pipeline`` with steps "preprocessor" and "classifier".

    Raises:
        ValueError: If ``model_type`` is neither "lr" nor "rf". Previously any
            other value (e.g. a typo like "LR") silently produced a random forest.
    """
    if model_type not in ("lr", "rf"):
        raise ValueError(f"Unknown model_type {model_type!r}; expected 'lr' or 'rf'.")

    # Copy caller-supplied lists so later mutation by the caller cannot alter the pipeline.
    if numeric_features is None:
        numeric_features = ["study_hours", "sleep_hours", "attendance_rate", "stress_level"]
    else:
        numeric_features = list(numeric_features)
    if categorical_features is None:
        categorical_features = ["study_type"]
    else:
        categorical_features = list(categorical_features)

    # Numeric branch: fill gaps with the column mean, then standardize.
    numeric_transformer = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="mean")),
        ("scaler", StandardScaler()),
    ])

    # Categorical branch: fill gaps with the most frequent value, then one-hot encode.
    categorical_transformer = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("onehot", OneHotEncoder(handle_unknown="ignore")),
    ])

    preprocessor = ColumnTransformer(
        transformers=[
            ("num", numeric_transformer, numeric_features),
            ("cat", categorical_transformer, categorical_features),
        ]
    )

    if model_type == "lr":
        # Extra iteration headroom so the solver is not stopped before convergence;
        # a fit that already converges within the default limit is unchanged.
        clf = LogisticRegression(max_iter=1000, random_state=random_state)
    else:
        clf = RandomForestClassifier(n_estimators=500, max_depth=5, random_state=random_state)

    return Pipeline(steps=[
        ("preprocessor", preprocessor),
        ("classifier", clf),
    ])
|
||
|
|
|
||
|
|
def _fit_and_score(pipeline, X_train, y_train, X_test, y_test):
    """Fit ``pipeline`` on the training split; return (test predictions, test F1)."""
    pipeline.fit(X_train, y_train)
    y_pred = pipeline.predict(X_test)
    # f1_score defaults to binary averaging with pos_label=1;
    # assumes is_pass is a 0/1 label — TODO confirm against src.data.
    return y_pred, f1_score(y_test, y_pred)


def _print_error_analysis(X_test, y_test, y_pred, n_preview=3):
    """Print the number of misclassified test rows and preview the first few."""
    test_df = X_test.copy()
    test_df["True Label"] = y_test      # aligned on the DataFrame index
    test_df["Pred Label"] = y_pred      # numpy array, assigned positionally
    errors = test_df[test_df["True Label"] != test_df["Pred Label"]]
    print(f"总计错误样本数: {len(errors)}")
    if not errors.empty:
        print("典型错误样本预览:")
        print(errors.head(n_preview))


def train():
    """Run the end-to-end training workflow and persist the best pipeline.

    1. Build the dataset with ``generate_data`` + ``preprocess_data`` and make
       an 80/20 train/test split (``random_state=42``).
    2. Fit a logistic-regression baseline and a random-forest pipeline and
       report each one's test F1.
    3. Keep whichever model has the higher test F1 (ties go to the random
       forest) and print its classification report.
    4. Print an error analysis of that model's misclassified test rows.
    5. Save the chosen fitted pipeline to ``MODEL_PATH`` with joblib.
    """
    print(">>> 1. 数据准备")
    df = generate_data(n_samples=2000)
    df = preprocess_data(df)

    X = df.drop(columns=["is_pass"])
    y = df["is_pass"]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    print(f"训练集大小: {X_train.shape}, 测试集大小: {X_test.shape}")

    print("\n>>> 2. 模型训练与对比")
    # Model A: logistic regression (baseline).
    pipe_lr = get_pipeline("lr")
    y_pred_lr, f1_lr = _fit_and_score(pipe_lr, X_train, y_train, X_test, y_test)
    print(f"[Baseline - LogisticRegression] F1: {f1_lr:.4f}")

    # Model B: random forest (target).
    pipe_rf = get_pipeline("rf")
    y_pred_rf, f1_rf = _fit_and_score(pipe_rf, X_train, y_train, X_test, y_test)
    print(f"[Target - RandomForest] F1: {f1_rf:.4f}")

    print("\n>>> 3. 选择 F1 更高的模型并进行详细评估")
    # Actually compare the two models; the random forest was previously kept
    # unconditionally even when the baseline scored higher.
    if f1_lr > f1_rf:
        best_name, best_model, y_pred_best = "LogisticRegression", pipe_lr, y_pred_lr
    else:
        best_name, best_model, y_pred_best = "RandomForest", pipe_rf, y_pred_rf
    print(f"最佳模型: {best_name}")
    print(classification_report(y_test, y_pred_best))

    print("\n>>> 4. 误差分析 (Error Analysis)")
    _print_error_analysis(X_test, y_test, y_pred_best)

    print("\n>>> 5. 保存最佳模型")
    os.makedirs(MODELS_DIR, exist_ok=True)
    joblib.dump(best_model, MODEL_PATH)
    print(f"模型 Pipeline 已完整保存至 {MODEL_PATH}")
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Script entry point, e.g. `python src/train.py` run from the project root.
    train()
|