CourseDesign/tests/test_train.py

49 lines
1.7 KiB
Python

import sys
import os
import pytest
from sklearn.pipeline import Pipeline
from unittest.mock import patch, MagicMock
# Ensure src is in path
sys.path.append(os.getcwd())
from src.train import get_pipeline, train
def test_get_pipeline_structure():
    """Check that ``get_pipeline("rf")`` builds a scikit-learn ``Pipeline``.

    The returned pipeline must expose both a ``preprocessor`` step and a
    ``classifier`` step by name.
    """
    built = get_pipeline("rf")
    assert isinstance(built, Pipeline)

    step_names = built.named_steps
    for required in ("preprocessor", "classifier"):
        assert required in step_names
def test_train_function_runs(tmp_path):
    """Smoke-test ``train()`` end to end against a tiny dataset.

    ``src.train.MODELS_DIR`` and ``src.train.MODEL_PATH`` are redirected into
    pytest's ``tmp_path``, and ``src.train.generate_data`` is stubbed to return
    a 10-sample frame produced by the real generator, so the data keeps the
    structure the pipeline expects while the run stays fast.

    Any exception raised by ``train()`` is deliberately left to propagate:
    pytest already reports it as a test failure, with the original traceback.
    """
    # Imported from src.data (not src.train) so this is the real generator,
    # unaffected by the patch on src.train.generate_data below.
    from src.data import generate_data

    models_dir = tmp_path / "models"
    model_path = models_dir / "model.pkl"

    # Build the small, realistically-shaped dataset before any patching starts.
    small_df = generate_data(n_samples=10)

    # The constants are patched as plain strings because (per the original
    # note in this test) src/train.py feeds MODELS_DIR to os.path.join.
    with patch("src.train.MODELS_DIR", str(models_dir)), \
            patch("src.train.MODEL_PATH", str(model_path)), \
            patch("src.train.generate_data", return_value=small_df):
        train()

    # Training must have written the model to the redirected location.
    assert model_path.exists()