CourseDesign/tests/test_model.py
2026-01-09 14:30:23 +08:00

33 lines
782 B
Python

import os
import joblib
from src.infer import predict_pass_prob
from src.train import MODEL_PATH, train
def test_train_creates_model():
# 确保模型不存在或被覆盖
if os.path.exists(MODEL_PATH):
os.remove(MODEL_PATH)
train()
assert os.path.exists(MODEL_PATH)
model = joblib.load(MODEL_PATH)
assert model is not None
def test_inference():
# 确保模型存在
if not os.path.exists(MODEL_PATH):
train()
# 高概率情况 (大量学习/睡眠/出勤 + Group学习 + 低压力)
prob_high = predict_pass_prob(15, 8, 1.0, 1, "Group")
assert prob_high > 0.5
# 低概率情况 (不学习/不睡/缺勤 + 在线 + 高压力)
prob_low = predict_pass_prob(0, 3, 0.0, 5, "Online")
assert prob_low < 0.5