# CourseDesign/tests/test_infer.py
# (Recovered from a repository web view: 67 lines, 2.2 KiB, Python.)

import sys
import os
import pytest
from unittest.mock import patch
# Ensure src is in path
sys.path.append(os.getcwd())
from src.infer import predict_pass_prob, explain_prediction, load_model
# Inference tests need a real, loadable model file; build one once per module.
@pytest.fixture(scope="module")
def train_dummy_model(tmp_path_factory):
    """Train a small pipeline with the project's own helpers and pickle it.

    Going through ``src.train`` (data generation, preprocessing and the
    pipeline factory) instead of hand-building an estimator here makes the
    fixture a small integration check of the train -> save -> infer path.

    Args:
        tmp_path_factory: pytest's temp-directory factory.

    Returns:
        str: Path of the saved ``model.pkl`` inside a fresh temp directory.
    """
    destination = tmp_path_factory.mktemp("models") / "model.pkl"

    # Imported lazily: these are only needed when the fixture actually runs.
    from src.train import get_pipeline, generate_data, preprocess_data
    import joblib

    frame = preprocess_data(generate_data(n_samples=20))
    features = frame.drop(columns=["is_pass"])
    labels = frame["is_pass"]

    model = get_pipeline("rf")
    model.fit(features, labels)
    joblib.dump(model, destination)
    return str(destination)
def test_predict_pass_prob(train_dummy_model):
    """Score one sample with the fixture model; the result must be a probability."""
    sample = {
        "study_hours": 5.0,
        "sleep_hours": 7.0,
        "attendance_rate": 0.9,
        "stress_level": 3,
        "study_type": "Self",
    }
    # Clear the module-level model cache and point inference at the fixture file.
    with patch("src.infer._MODEL", None), patch(
        "src.infer.MODEL_PATH", train_dummy_model
    ):
        probability = predict_pass_prob(**sample)
    assert 0.0 <= probability <= 1.0
def test_explain_prediction(train_dummy_model):
    """explain_prediction() must return a text report with the importance-ranking header."""
    # Clear the module-level model cache and point inference at the fixture file.
    with patch("src.infer._MODEL", None), patch(
        "src.infer.MODEL_PATH", train_dummy_model
    ):
        report = explain_prediction()
    assert isinstance(report, str)
    # Header text means "model feature-importance ranking".
    assert "模型特征重要性排名" in report
@patch("src.infer._MODEL", None)  # Empty cache so a load is actually attempted
def test_load_model_missing(tmp_path):
    """A missing model file must surface as ``FileNotFoundError``.

    The path is placed under pytest's fresh per-test ``tmp_path``, so it is
    guaranteed not to exist regardless of the directory the suite is run
    from (a bare relative path would be resolved against the CWD).
    """
    missing_path = str(tmp_path / "non_existent_path" / "model.pkl")
    with patch("src.infer.MODEL_PATH", missing_path):
        # predict_pass_prob is expected to load the model on demand
        # (via load_model) — TODO confirm against src.infer.
        with pytest.raises(FileNotFoundError):
            predict_pass_prob(
                study_hours=1,
                sleep_hours=1,
                attendance_rate=1,
                stress_level=1,
                study_type="Self",
            )