"""Tests for the training module (``src.train``)."""

from unittest.mock import patch

import pytest
from sklearn.pipeline import Pipeline

from src.data import generate_data
from src.train import get_pipeline, train


def test_get_pipeline_structure():
    """The "rf" pipeline is a Pipeline with "preprocessor" and "classifier" steps."""
    pipeline = get_pipeline("rf")
    assert isinstance(pipeline, Pipeline)
    assert "preprocessor" in pipeline.named_steps
    assert "classifier" in pipeline.named_steps


def test_get_pipeline_lr():
    """The logistic-regression option ("lr") also yields a Pipeline."""
    pipeline = get_pipeline("lr")
    assert isinstance(pipeline, Pipeline)


def test_train_function_runs(tmp_path):
    """train() completes on a small dataset and writes the model file.

    ``src.train.MODELS_DIR`` and ``src.train.MODEL_PATH`` are redirected into
    pytest's ``tmp_path`` so the test never writes to the real models
    directory, and ``src.train.generate_data`` is replaced by a tiny dataset
    produced by the real generator to keep training fast.
    """
    models_dir = tmp_path / "models"
    model_path = models_dir / "model.pkl"

    # Build the small dataset with the real generator *before* patching, so it
    # is unambiguous which call is real and which one train() sees (the mock).
    small_df = generate_data(n_samples=20)

    with (
        patch("src.train.MODELS_DIR", models_dir),
        patch("src.train.MODEL_PATH", model_path),
        patch("src.train.generate_data", return_value=small_df),
    ):
        # Call train() directly: if it raises, pytest reports the original
        # exception with its full traceback. Wrapping it in a blanket
        # ``except Exception: pytest.fail(...)`` would only obscure that.
        train()

    assert model_path.exists()