112 lines
3.0 KiB
Python
112 lines
3.0 KiB
Python
"""数据模块测试
|
|
|
|
测试 Polars 数据生成、Pandera 校验和预处理功能。
|
|
"""
|
|
|
|
import polars as pl
|
|
import pytest
|
|
|
|
from src.data import (
|
|
CleanStudentDataSchema,
|
|
RawStudentDataSchema,
|
|
generate_data,
|
|
get_feature_columns,
|
|
preprocess_data,
|
|
validate_clean_data,
|
|
validate_raw_data,
|
|
)
|
|
|
|
|
|
def test_generate_data_structure():
    """Generated data is a Polars DataFrame with the requested rows and all expected columns."""
    frame = generate_data(n_samples=50)

    assert isinstance(frame, pl.DataFrame)
    assert frame.height == 50

    required = (
        "study_hours",
        "sleep_hours",
        "attendance_rate",
        "study_type",
        "stress_level",
        "is_pass",
    )
    present = set(frame.columns)
    # Collect every absent column so a failure reports all of them at once.
    missing = [name for name in required if name not in present]
    assert not missing
|
|
|
|
|
|
def test_generate_data_content_range():
    """Generated columns stay within their expected value bounds."""
    frame = generate_data(n_samples=50)

    study = frame["study_hours"]
    assert 0 <= study.min()
    assert study.max() <= 20

    assert 0 <= frame["sleep_hours"].min()

    stress = frame["stress_level"]
    assert 1 <= stress.min()
    assert stress.max() <= 5

    # The target is a binary label.
    assert frame["is_pass"].is_in([0, 1]).all()
|
|
|
|
|
|
def test_generate_data_missing_values():
    """Generated data contains some, but not only, missing ``attendance_rate`` values."""
    df = generate_data(n_samples=500, random_seed=42)

    # attendance_rate is null with ~5% probability (per this test's original note),
    # so 500 rows should give roughly 25 nulls. A plain ``null_count >= 0`` check
    # can never fail; require at least one null to confirm injection happens.
    null_count = df["attendance_rate"].null_count()
    assert null_count > 0
    # ...and the column must not be entirely missing.
    assert null_count < len(df)
|
|
|
|
|
|
def test_validate_raw_data():
    """The lenient raw-data schema accepts data that still contains missing values."""
    # Use the same seeded, larger sample as the other null-related tests. With
    # only 50 rows, no explicit seed and a ~5% null rate, about 1 run in 13 had
    # no nulls, so the "lenient" path was not reliably exercised.
    df = generate_data(n_samples=500, random_seed=42)
    assert df["attendance_rate"].null_count() > 0

    # Validation should succeed even though nulls are present.
    validated = validate_raw_data(df)
    assert isinstance(validated, pl.DataFrame)
|
|
|
|
|
|
def test_validate_clean_data():
    """Null-free data passes the strict clean-data schema."""
    without_nulls = generate_data(n_samples=50).drop_nulls()
    result = validate_clean_data(without_nulls)
    assert isinstance(result, pl.DataFrame)
|
|
|
|
|
|
def test_preprocess_data_removes_nulls():
    """Preprocessing leaves no missing ``attendance_rate`` values and never adds rows."""
    df = generate_data(n_samples=500, random_seed=42)

    # Guard: the test is only meaningful if the input actually contains nulls;
    # otherwise the null_after == 0 check below would pass trivially.
    null_before = df["attendance_rate"].null_count()
    assert null_before > 0

    df_clean = preprocess_data(df, validate=True)

    null_after = df_clean["attendance_rate"].null_count()
    assert null_after == 0
    assert len(df_clean) <= len(df)
|
|
|
|
|
|
def test_preprocess_data_removes_duplicates():
    """Exact duplicate rows are collapsed to one during preprocessing."""
    # Rows 1 and 2 (0-based) are identical; rows 0 and 3 are distinct from them.
    rows = {
        "study_hours": [1.0, 2.0, 2.0, 3.0],
        "sleep_hours": [7.0] * 4,
        "attendance_rate": [0.8] * 4,
        "stress_level": [1, 2, 2, 3],
        "study_type": ["Self"] * 4,
        "is_pass": [0, 1, 1, 1],
    }
    deduped = preprocess_data(pl.DataFrame(rows), validate=True)
    assert deduped.height == 3
|
|
|
|
|
|
def test_get_feature_columns():
    """The feature helper returns (numeric, categorical) lists containing known features."""
    numeric, categorical = get_feature_columns()
    assert "study_hours" in numeric
    assert "study_type" in categorical
|
|
|
|
|
|
def test_schema_classes_exist():
    """Both data schemas are usable validators and are distinct objects.

    ``X is not None`` is always true once the import succeeds, so it tests
    nothing. Instead, check that each schema (presumably a Pandera model or
    schema object) exposes a callable ``validate`` and that the raw and clean
    schemas are not the same object.
    """
    for schema in (RawStudentDataSchema, CleanStudentDataSchema):
        assert callable(getattr(schema, "validate", None))
    assert RawStudentDataSchema is not CleanStudentDataSchema
|