generated from Python-2026Spring/assignment-05-final-project-template
31 lines
937 B
Python
31 lines
937 B
Python
# save_model.py
|
||
import joblib
|
||
from src.data import load_data, split_data
|
||
from src.model import build_preprocessor
|
||
from sklearn.ensemble import RandomForestClassifier
|
||
from sklearn.pipeline import Pipeline
|
||
|
||
if __name__ == "__main__":
    # Script entry point: fit the churn pipeline and persist it to disk.

    # Load the data and unpack the four splits returned by split_data.
    # Only the training portions are used below; the test portions are
    # unpacked but not referenced again in this script.
    dataset = load_data()
    X_train, X_test, y_train, y_test = split_data(dataset)

    # The preprocessor is built from the training features.
    feature_prep = build_preprocessor(X_train)

    # Train the tuned random forest model.
    forest_params = dict(
        n_estimators=200,
        max_depth=15,
        min_samples_split=8,
        min_samples_leaf=4,
        class_weight='balanced_subsample',
        random_state=42,  # fixed seed passed to the estimator
        n_jobs=-1,
    )
    forest = RandomForestClassifier(**forest_params)
    pipeline_steps = [
        ('preprocessor', feature_prep),
        ('classifier', forest),
    ]
    churn_pipeline = Pipeline(steps=pipeline_steps)
    churn_pipeline.fit(X_train, y_train)

    # Save the fitted pipeline (preprocessor + classifier) as one artifact.
    output_path = 'telco_churn_model.pkl'
    joblib.dump(churn_pipeline, output_path)
    print("✅ 模型已保存为:telco_churn_model.pkl")