import sys
import os
import pytest
from unittest.mock import patch

# Make the project root importable so `src.*` resolves. NOTE: this uses the
# current working directory, so pytest is assumed to be launched from the
# project root.
sys.path.append(os.getcwd())

from src.infer import predict_pass_prob, explain_prediction, load_model  # noqa: E402


@pytest.fixture(scope="module")
def train_dummy_model(tmp_path_factory):
    """Train a quick dummy model and save it into a temporary directory.

    The real ``src.train`` helpers are reused (instead of hand-building a
    pipeline) so the saved artifact matches what training actually produces,
    which makes these tests closer to integration tests.

    Returns:
        The path of the saved model file, as a ``str``.
    """
    model_path = tmp_path_factory.mktemp("models") / "model.pkl"

    from src.train import get_pipeline, generate_data, preprocess_data
    import joblib

    df = preprocess_data(generate_data(n_samples=20))
    X = df.drop(columns=["is_pass"])
    y = df["is_pass"]

    pipeline = get_pipeline("rf")
    pipeline.fit(X, y)
    joblib.dump(pipeline, model_path)
    return str(model_path)


@patch("src.infer._MODEL", None)  # Reset the module-level cached model.
def test_predict_pass_prob(train_dummy_model):
    """Predicting with the dummy model returns a probability in [0, 1]."""
    with patch("src.infer.MODEL_PATH", train_dummy_model):
        proba = predict_pass_prob(
            study_hours=5.0,
            sleep_hours=7.0,
            attendance_rate=0.9,
            stress_level=3,
            study_type="Self",
        )
    assert 0.0 <= proba <= 1.0


@patch("src.infer._MODEL", None)  # Reset the module-level cached model.
def test_explain_prediction(train_dummy_model):
    """The explanation is a string containing the feature-importance header."""
    with patch("src.infer.MODEL_PATH", train_dummy_model):
        explanation = explain_prediction()
    assert isinstance(explanation, str)
    # The expected header text means "model feature importance ranking".
    assert "模型特征重要性排名" in explanation


@patch("src.infer._MODEL", None)  # Reset the module-level cached model.
def test_load_model_missing(tmp_path):
    """A missing model file must surface as FileNotFoundError."""
    # A path inside pytest's fresh, empty tmp dir is guaranteed not to exist,
    # unlike a relative path whose meaning depends on the working directory.
    missing_path = str(tmp_path / "non_existent_path" / "model.pkl")
    with patch("src.infer.MODEL_PATH", missing_path):
        with pytest.raises(FileNotFoundError):
            # predict_pass_prob calls load_model internally.
            predict_pass_prob(1, 1, 1, 1, "Self")