"""Tests for the inference module (``src.infer``)."""

from pathlib import Path
from unittest.mock import patch

import pytest

from src.infer import (
    explain_prediction,
    predict_pass_prob,
    reset_model_cache,
)


@pytest.fixture(autouse=True)
def _isolated_model_cache():
    """Clear the inference model cache before and after every test.

    Resetting before the test makes each test load the model from whatever
    ``MODEL_PATH`` it patches in. Resetting afterwards keeps a model loaded
    from this module's temporary path from staying cached and leaking into
    tests in other modules that share ``src.infer``'s cache.
    """
    reset_model_cache()
    yield
    reset_model_cache()


@pytest.fixture(scope="module")
def train_dummy_model(tmp_path_factory):
    """Train a small throwaway model and return the path it was saved to.

    The model is trained once per test module on 20 generated samples and
    serialized with ``joblib`` into a pytest-managed temporary directory.
    """
    models_dir = tmp_path_factory.mktemp("models")
    model_path = models_dir / "model.pkl"

    # Imported inside the fixture so training-only dependencies are loaded
    # only when a test actually requests this fixture.
    import joblib

    from src.data import generate_data, preprocess_data
    from src.train import get_pipeline

    df = generate_data(n_samples=20)
    df = preprocess_data(df)

    # Convert to pandas before splitting features/target and fitting.
    df_pandas = df.to_pandas()
    X = df_pandas.drop(columns=["is_pass"])
    y = df_pandas["is_pass"]

    pipeline = get_pipeline("rf")
    pipeline.fit(X, y)
    joblib.dump(pipeline, model_path)
    return model_path


def test_predict_pass_prob(train_dummy_model):
    """predict_pass_prob returns a probability within [0, 1]."""
    with patch("src.infer.MODEL_PATH", train_dummy_model):
        proba = predict_pass_prob(
            study_hours=5.0,
            sleep_hours=7.0,
            attendance_rate=0.9,
            stress_level=3,
            study_type="Self",
        )
    assert 0.0 <= proba <= 1.0


def test_explain_prediction(train_dummy_model):
    """explain_prediction returns a string containing the importance-ranking header."""
    with patch("src.infer.MODEL_PATH", train_dummy_model):
        explanation = explain_prediction()
    assert isinstance(explanation, str)
    assert "模型特征重要性排名" in explanation


def test_load_model_missing(tmp_path):
    """A missing model file surfaces as FileNotFoundError."""
    # Build the path under pytest's per-test temp dir so it is guaranteed
    # not to exist, independent of the current working directory.
    missing_path = tmp_path / "non_existent_path" / "model.pkl"
    with patch("src.infer.MODEL_PATH", missing_path):
        with pytest.raises(FileNotFoundError):
            predict_pass_prob(1, 1, 1, 1, "Self")