53 lines
1.7 KiB
Python
53 lines
1.7 KiB
Python
import sys
|
|
import os
|
|
import pandas as pd
|
|
import numpy as np
|
|
import pytest
|
|
|
|
# Put the current working directory on the import path so that the
# `src` package (imported below) resolves.
# NOTE(review): this only works when pytest is launched from the project root.
sys.path.extend([os.getcwd()])
|
|
|
|
from src.data import generate_data, preprocess_data
|
|
|
|
def test_generate_data_structure():
    """Check that generate_data returns a 50-row DataFrame with all expected columns."""
    n_rows = 50
    frame = generate_data(n_samples=n_rows)

    assert isinstance(frame, pd.DataFrame)
    assert len(frame) == n_rows

    required = (
        "study_hours",
        "sleep_hours",
        "attendance_rate",
        "study_type",
        "stress_level",
        "is_pass",
    )
    # Collect every absent column so a failure reports all of them at once.
    missing = [name for name in required if name not in frame.columns]
    assert not missing
|
|
|
|
def test_generate_data_content_range():
    """Test if generated data falls within expected value ranges.

    A fixed ``random_seed`` makes the run reproducible, so a failure here
    reflects a real change in generate_data rather than an unlucky draw.
    """
    df = generate_data(n_samples=50, random_seed=42)

    assert df["study_hours"].min() >= 0
    # Deliberately loose upper bound: the generator is believed to produce
    # roughly 0-15 hours, and 20 leaves headroom.
    # NOTE(review): confirm the 0-15 range against src.data.
    assert df["study_hours"].max() <= 20
    assert df["sleep_hours"].min() >= 0
    assert df["stress_level"].between(1, 5).all()
    assert df["is_pass"].isin([0, 1]).all()
|
|
|
|
def test_generate_data_missing_values():
    """Test that generate_data injects some, but not only, missing values.

    A check of ``count >= 0`` would be vacuous (a count can never be
    negative), so this asserts that NaNs actually appear in
    ``attendance_rate`` without the column being entirely missing.
    """
    # 500 rows with a fixed seed: at a ~5% NaN rate, roughly 25 missing
    # values are expected, and the seed makes the outcome deterministic.
    df = generate_data(n_samples=500, random_seed=42)

    # The generator gives attendance_rate a ~5% chance of being NaN.
    # NOTE(review): rate taken from this test's original note; confirm in src.data.
    n_missing = df["attendance_rate"].isnull().sum()
    assert n_missing > 0
    assert n_missing < len(df)
|
|
|
|
def test_preprocess_data():
    """Test basic preprocessing (deduplication).

    The row (2, 2) appears twice in the input; after preprocessing, three
    rows should remain and none of them should be duplicates.
    """
    df = pd.DataFrame({
        "a": [1, 2, 2, 3],
        "b": [1, 2, 2, 3],
    })

    clean_df = preprocess_data(df)

    assert len(clean_df) == 3
    # The row count alone would also pass if an arbitrary row were dropped;
    # also require that no duplicate rows survive.
    assert not clean_df.duplicated().any()
|