67 lines
2.2 KiB
Python
67 lines
2.2 KiB
Python
|
|
import os
import sys
from unittest.mock import patch

import pytest

# Make the project's `src` package importable by adding the current working
# directory to the import path. This only works when the test suite is run
# from the directory that contains `src/`.
_project_root = os.getcwd()
sys.path.append(_project_root)

from src.infer import explain_prediction, load_model, predict_pass_prob  # noqa: E402
|
||
|
|
|
||
|
|
# Inference tests need a real model file on disk; this fixture provides one.
@pytest.fixture(scope="module")
def train_dummy_model(tmp_path_factory):
    """Train a tiny model with the project's training helpers and pickle it.

    The pipeline is produced by ``src.train`` (not assembled by hand) so the
    inference tests also act as an integration check of the training code.
    Being module-scoped, the model is trained once and shared by every test
    in this file.

    Returns:
        Path (as ``str``) of the joblib-dumped pipeline inside a fresh
        temporary ``models`` directory; the tests patch it over
        ``src.infer.MODEL_PATH``.
    """
    target_file = tmp_path_factory.mktemp("models") / "model.pkl"

    from src.train import get_pipeline, generate_data, preprocess_data
    import joblib

    # Small synthetic dataset, run through the same preprocessing as training.
    frame = preprocess_data(generate_data(n_samples=20))
    features = frame.drop(columns=["is_pass"])
    labels = frame["is_pass"]

    # "rf" selects one of the pipeline variants offered by src.train.
    model = get_pipeline("rf")
    model.fit(features, labels)
    joblib.dump(model, target_file)

    return str(target_file)
|
||
|
|
|
||
|
|
def test_predict_pass_prob(train_dummy_model):
    """Predicting with the dummy model returns a value in the closed range [0, 1]."""
    student = {
        "study_hours": 5.0,
        "sleep_hours": 7.0,
        "attendance_rate": 0.9,
        "stress_level": 3,
        "study_type": "Self",
    }

    # Reset the global cached model (`_MODEL`) and point `MODEL_PATH` at the
    # fixture's file, so inference is expected to load the dummy model.
    with patch("src.infer._MODEL", None), patch(
        "src.infer.MODEL_PATH", train_dummy_model
    ):
        probability = predict_pass_prob(**student)
        assert 0.0 <= probability <= 1.0
|
||
|
|
|
||
|
|
def test_explain_prediction(train_dummy_model):
    """The explanation for the dummy model is a string with the importance-ranking text."""
    # Reset the global cached model and redirect MODEL_PATH to the fixture's file.
    with patch("src.infer._MODEL", None), patch(
        "src.infer.MODEL_PATH", train_dummy_model
    ):
        text = explain_prediction()
        assert isinstance(text, str)
        # Expected phrase (Chinese): "model feature importance ranking".
        assert "模型特征重要性排名" in text
|
||
|
|
|
||
|
|
@patch("src.infer._MODEL", None)  # Reset global cached model
def test_load_model_missing(tmp_path):
    """Prediction raises FileNotFoundError when the model file does not exist.

    The model path lives inside pytest's per-test temporary directory, so it
    is guaranteed to be absent no matter which working directory the suite is
    launched from. A bare relative path would instead be resolved against the
    current working directory.
    """
    missing_model = tmp_path / "non_existent_path" / "model.pkl"

    with patch("src.infer.MODEL_PATH", str(missing_model)):
        # predict_pass_prob calls load_model internally. With no cached model
        # and no file on disk, the load is expected to raise FileNotFoundError.
        with pytest.raises(FileNotFoundError):
            predict_pass_prob(
                study_hours=1,
                sleep_hours=1,
                attendance_rate=1,
                stress_level=1,
                study_type="Self",
            )
|