test01/src/data.py

113 lines
4.7 KiB
Python
Raw Normal View History

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import polars as pl
# Configure the root logger at INFO level as an import-time side effect,
# so the progress messages emitted by this module are shown by default.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by every function and method in this file.
logger = logging.getLogger(__name__)
class CreditCardDataProcessor:
    """Load, validate, chronologically split and vectorise a transaction CSV.

    The table is read with polars and must contain a ``Time`` column (used
    for ordering) and a ``Class`` column whose values 1 and 0 are counted as
    fraud and non-fraud transactions respectively.

    Typical call order: ``load_data`` -> ``validate_data`` ->
    ``split_data_by_time`` -> ``prepare_features_labels``.
    """

    def __init__(self, file_path: str):
        """Remember the CSV location; no file I/O happens here.

        Args:
            file_path: Path of the CSV file that ``load_data`` reads.
        """
        self.file_path = file_path
        # Full table as read from disk (set by ``load_data``).
        self.data: Optional[pl.DataFrame] = None
        # Chronological partitions (set by ``split_data_by_time``).
        self.train_data: Optional[pl.DataFrame] = None
        self.test_data: Optional[pl.DataFrame] = None
        # NumPy views of the partitions (set by ``prepare_features_labels``).
        self.train_features: Optional[np.ndarray] = None
        self.train_labels: Optional[np.ndarray] = None
        self.test_features: Optional[np.ndarray] = None
        self.test_labels: Optional[np.ndarray] = None

    def _ensure_loaded(self) -> None:
        """Load the CSV on first use so callers can rely on ``self.data``."""
        if self.data is None:
            self.load_data()

    def _class_counts(self) -> Tuple[int, int]:
        """Return ``(fraud_count, normal_count)`` for the loaded table."""
        fraud_count = self.data.filter(pl.col("Class") == 1).height
        normal_count = self.data.filter(pl.col("Class") == 0).height
        return fraud_count, normal_count

    def load_data(self) -> None:
        """Read ``self.file_path`` into ``self.data`` and log class counts.

        The ``Time`` column is forced to ``Float64`` while parsing.

        Raises:
            Exception: Any error raised while reading or counting is logged
                and re-raised unchanged.
        """
        logger.info(f"加载数据集: {self.file_path}")
        try:
            self.data = pl.read_csv(
                self.file_path,
                schema_overrides={"Time": pl.Float64},
            )
            logger.info(f"数据集加载成功,形状: {self.data.shape}")
            fraud_count, normal_count = self._class_counts()
            logger.info(f"欺诈交易数量: {fraud_count}, 非欺诈交易数量: {normal_count}")
        except Exception as e:
            logger.error(f"加载数据失败: {e}")
            raise

    def validate_data(self) -> None:
        """Log the total number of missing values and the label distribution.

        The table is loaded first if ``load_data`` has not been called yet.
        Missing values only produce a warning; nothing is modified.
        """
        self._ensure_loaded()
        logger.info("开始数据验证...")
        # null_count() yields one row of per-column counts; summing that row
        # horizontally gives the total for the whole table.
        missing_values = self.data.null_count()
        total_missing = missing_values.sum_horizontal().item()
        if total_missing > 0:
            logger.warning(f"发现缺失值: {total_missing}")
        else:
            logger.info("无缺失值,数据完整性良好")
        class_dist = self.data.group_by("Class").agg(pl.len().alias("count")).to_dict()
        logger.info(f"标签分布: {class_dist}")

    def split_data_by_time(self, test_ratio: float = 0.2) -> "Tuple[pl.DataFrame, pl.DataFrame]":
        """Split the table into an earlier training part and a later test part.

        Rows are sorted by ``Time``; the first ``int(n * (1 - test_ratio))``
        rows become the training set and the remainder the test set. The
        table is loaded first if ``load_data`` has not been called yet.

        Args:
            test_ratio: Fraction of rows assigned to the test set; must lie
                strictly between 0 and 1.

        Returns:
            ``(train_data, test_data)``, also stored on the instance.

        Raises:
            ValueError: If ``test_ratio`` is not strictly between 0 and 1.
        """
        # Reject bad ratios before doing any work; this also rejects NaN.
        if not 0.0 < test_ratio < 1.0:
            raise ValueError(f"test_ratio 必须在 (0, 1) 区间内, 当前值: {test_ratio}")
        self._ensure_loaded()
        logger.info(f"按照时间顺序划分数据集,测试集比例: {test_ratio}")
        # maintain_order keeps rows with equal timestamps in file order, so
        # the rows that fall on each side of the cut are reproducible.
        sorted_data = self.data.sort("Time", maintain_order=True)
        split_index = int(sorted_data.height * (1 - test_ratio))
        self.train_data = sorted_data[:split_index]
        self.test_data = sorted_data[split_index:]
        logger.info(f"训练集形状: {self.train_data.shape}, 测试集形状: {self.test_data.shape}")
        train_max_time = self.train_data["Time"].max()
        test_min_time = self.test_data["Time"].min()
        logger.info(f"训练集最大时间: {train_max_time}, 测试集最小时间: {test_min_time}")
        if train_max_time is None or test_min_time is None:
            # A very small table can leave one side empty; max()/min() are
            # then None and cannot be compared.
            logger.warning("训练集或测试集为空,无法校验时间顺序")
        elif train_max_time <= test_min_time:
            logger.info("时间划分正确,训练集时间早于测试集")
        else:
            logger.warning("时间划分存在问题,训练集时间晚于测试集")
        return self.train_data, self.test_data

    def prepare_features_labels(self, feature_cols: Optional[List[str]] = None, label_col: str = "Class") -> None:
        """Convert both partitions into NumPy feature matrices and label vectors.

        Args:
            feature_cols: Columns used as features. Defaults to every column
                except ``label_col``.
            label_col: Name of the label column.

        Raises:
            RuntimeError: If ``split_data_by_time`` has not been called yet.
        """
        if self.train_data is None or self.test_data is None:
            raise RuntimeError("尚未划分数据集,请先调用 split_data_by_time()")
        logger.info("准备特征和标签...")
        if feature_cols is None:
            # Both partitions share the columns of the source table.
            feature_cols = [col for col in self.train_data.columns if col != label_col]
        logger.info(f"使用的特征列: {feature_cols}")
        self.train_features = self.train_data.select(feature_cols).to_numpy()
        # select() yields an (n, 1) array; flatten() turns it into shape (n,).
        self.train_labels = self.train_data.select(label_col).to_numpy().flatten()
        self.test_features = self.test_data.select(feature_cols).to_numpy()
        self.test_labels = self.test_data.select(label_col).to_numpy().flatten()
        logger.info(f"训练特征形状: {self.train_features.shape}, 训练标签形状: {self.train_labels.shape}")
        logger.info(f"测试特征形状: {self.test_features.shape}, 测试标签形状: {self.test_labels.shape}")

    def get_statistics(self) -> Dict[str, Any]:
        """Return summary counts for the full table, loading it if necessary.

        Returns:
            A dict with the row count, the number of non-label columns, the
            fraud and non-fraud counts, and their ratio (non-fraud / fraud).
            The ratio is ``float("inf")`` when there are no fraud rows.
        """
        self._ensure_loaded()
        # Filter once per class instead of repeating the scans for the ratio.
        fraud_count, normal_count = self._class_counts()
        imbalance = normal_count / fraud_count if fraud_count else float("inf")
        stats = {
            "总记录数": self.data.height,
            "特征数": len([col for col in self.data.columns if col != "Class"]),
            "欺诈交易数": fraud_count,
            "非欺诈交易数": normal_count,
            "不平衡比例": imbalance,
        }
        return stats
def load_data(file_path: str = "data/creditcard.csv") -> CreditCardDataProcessor:
    """Run the whole preparation pipeline and hand back the processor.

    Builds a ``CreditCardDataProcessor`` for ``file_path`` and calls, in
    order, its load, validate, time-split and feature-preparation steps,
    each with default arguments.

    Args:
        file_path: Location of the CSV file to process.

    Returns:
        The processor after all four steps have run.
    """
    pipeline = CreditCardDataProcessor(file_path)
    stages = (
        pipeline.load_data,
        pipeline.validate_data,
        pipeline.split_data_by_time,
        pipeline.prepare_features_labels,
    )
    # Each stage depends on state set by the previous one, so order matters.
    for run_stage in stages:
        run_stage()
    return pipeline
if __name__ == "__main__":
    # Script entry point: run the pipeline on the default CSV path and
    # print one "name: value" line per statistic.
    summary = load_data().get_statistics()
    print("\n=== 数据集统计信息 ===")
    for label in summary:
        print(f"{label}: {summary[label]}")