# Generated from Python-2026Spring/assignment-05-final-project-template
# save_model.py

from pathlib import Path

import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline

from src.data import load_data, split_data
from src.model import build_preprocessor


# Hyperparameters for the tuned random forest. Values are carried over
# unchanged from the original script; presumably they came from an earlier
# tuning run — TODO confirm where they were selected.
RF_PARAMS = {
    "n_estimators": 200,
    "max_depth": 15,
    "min_samples_split": 8,
    "min_samples_leaf": 4,
    # Reweight classes per bootstrap sample to counter class imbalance.
    "class_weight": "balanced_subsample",
    "random_state": 42,  # fixed seed for reproducible trees
    "n_jobs": -1,  # use all available CPU cores
}

# Default destination for the serialized pipeline (relative to the CWD).
DEFAULT_MODEL_PATH = "telco_churn_model.pkl"


def build_model(preprocessor):
    """Return an unfitted preprocessing + random-forest pipeline.

    Args:
        preprocessor: the object returned by ``build_preprocessor``; used as
            the pipeline's first step.

    Returns:
        A ``sklearn.pipeline.Pipeline`` with steps ``"preprocessor"`` and
        ``"classifier"`` (a ``RandomForestClassifier`` built from
        ``RF_PARAMS``).
    """
    return Pipeline(steps=[
        ("preprocessor", preprocessor),
        ("classifier", RandomForestClassifier(**RF_PARAMS)),
    ])


def main(output_path=DEFAULT_MODEL_PATH):
    """Train the tuned churn model on the training split and save it.

    Args:
        output_path: file the fitted pipeline is written to with
            ``joblib.dump``. Missing parent directories are created.

    Returns:
        The fitted pipeline.
    """
    # Load the dataset and split it. Only the training portion is used here;
    # the held-out split is discarded (presumably evaluated elsewhere —
    # TODO confirm).
    df = load_data()
    X_train, _, y_train, _ = split_data(df)

    # The preprocessor is built from the training features only.
    preprocessor = build_preprocessor(X_train)

    # Train the tuned random-forest pipeline.
    model = build_model(preprocessor)
    model.fit(X_train, y_train)

    # Save the model. Create the parent directory first so a custom
    # output_path in a new folder does not fail at dump time.
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(model, output_path)
    print(f"✅ 模型已保存为:{output_path}")
    return model


if __name__ == "__main__":
    main()