test01/src/data.py
2311020116lhh b6aef53ef0 feat: 初始化信用卡欺诈检测系统项目
- 添加项目基础结构,包括数据模型、训练、推理和Agent模块
- 实现数据处理、特征工程和模型训练功能
- 添加测试用例和文档说明
- 配置项目依赖和环境变量
2026-01-15 16:20:26 +08:00

113 lines
4.7 KiB
Python

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import polars as pl
# Import-time logging setup: route INFO-and-above to the root handler (per the
# logging docs this is a no-op if the root logger already has handlers), then
# obtain this module's own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class CreditCardDataProcessor:
    """Load, validate, chronologically split and vectorize the transaction CSV.

    Intended call order: ``load_data`` -> ``validate_data`` ->
    ``split_data_by_time`` -> ``prepare_features_labels``. Each step stores its
    result on the instance. Rows with label 1 are counted as fraud and rows
    with label 0 as normal transactions.
    """

    def __init__(self, file_path: str):
        """Remember the CSV path; nothing is read until ``load_data`` runs."""
        self.file_path = file_path
        # Full dataset as read from the CSV (set by load_data).
        self.data: Optional[pl.DataFrame] = None
        # Time-ordered train/test partitions (set by split_data_by_time).
        self.train_data: Optional[pl.DataFrame] = None
        self.test_data: Optional[pl.DataFrame] = None
        # NumPy matrices / 1-D label vectors (set by prepare_features_labels).
        self.train_features: Optional[np.ndarray] = None
        self.train_labels: Optional[np.ndarray] = None
        self.test_features: Optional[np.ndarray] = None
        self.test_labels: Optional[np.ndarray] = None

    def _require_data(self) -> pl.DataFrame:
        """Return the loaded dataset, or raise if ``load_data`` has not run."""
        if self.data is None:
            raise RuntimeError("数据尚未加载,请先调用 load_data()")
        return self.data

    def _require_split(self) -> Tuple[pl.DataFrame, pl.DataFrame]:
        """Return ``(train_data, test_data)``, or raise if not split yet."""
        if self.train_data is None or self.test_data is None:
            raise RuntimeError("数据尚未划分,请先调用 split_data_by_time()")
        return self.train_data, self.test_data

    @staticmethod
    def _class_counts(df: pl.DataFrame, label_col: str = "Class") -> Tuple[int, int]:
        """Return ``(fraud_count, normal_count)`` = rows with label 1 and 0."""
        fraud_count = df.filter(pl.col(label_col) == 1).height
        normal_count = df.filter(pl.col(label_col) == 0).height
        return fraud_count, normal_count

    def load_data(self) -> None:
        """Read the CSV into ``self.data`` and log its shape and class balance.

        The ``Time`` column is read as Float64 regardless of inferred type.

        Raises:
            Any exception from reading the file or counting classes; it is
            logged and re-raised unchanged.
        """
        logger.info(f"加载数据集: {self.file_path}")
        try:
            self.data = pl.read_csv(
                self.file_path,
                schema_overrides={"Time": pl.Float64},
            )
            logger.info(f"数据集加载成功,形状: {self.data.shape}")
            fraud_count, normal_count = self._class_counts(self.data)
            logger.info(f"欺诈交易数量: {fraud_count}, 非欺诈交易数量: {normal_count}")
        except Exception as e:
            logger.error(f"加载数据失败: {e}")
            raise

    def validate_data(self) -> None:
        """Log the total number of null cells and the label distribution.

        This only reports; the data is not modified or rejected.

        Raises:
            RuntimeError: If ``load_data`` has not been called.
        """
        logger.info("开始数据验证...")
        data = self._require_data()
        total_missing = data.null_count().sum_horizontal().item()
        if total_missing > 0:
            logger.warning(f"发现缺失值: {total_missing}")
        else:
            logger.info("无缺失值,数据完整性良好")
        # group_by does not keep a stable group order, so sort for reproducible
        # logs; as_series=False yields plain lists instead of Series reprs.
        class_dist = (
            data.group_by("Class")
            .agg(pl.len().alias("count"))
            .sort("Class")
            .to_dict(as_series=False)
        )
        logger.info(f"标签分布: {class_dist}")

    def split_data_by_time(self, test_ratio: float = 0.2) -> Tuple[pl.DataFrame, pl.DataFrame]:
        """Split chronologically: earliest rows train, latest ``test_ratio`` test.

        Sorting by ``Time`` before slicing keeps later transactions out of the
        training partition.

        Args:
            test_ratio: Fraction of rows (in time order) placed in the test
                set; must be strictly between 0 and 1.

        Returns:
            ``(train_data, test_data)``, also stored on the instance.

        Raises:
            ValueError: If ``test_ratio`` is outside (0, 1) or either partition
                would be empty.
            RuntimeError: If ``load_data`` has not been called.
        """
        # Without this check a ratio > 1 yields a negative index and a silently
        # wrong split; 0 or 1 yields an empty partition.
        if not 0 < test_ratio < 1:
            raise ValueError(f"test_ratio 必须在 (0, 1) 区间内, 实际为: {test_ratio}")
        logger.info(f"按照时间顺序划分数据集,测试集比例: {test_ratio}")
        sorted_data = self._require_data().sort("Time")
        split_index = int(sorted_data.height * (1 - test_ratio))
        if split_index <= 0 or split_index >= sorted_data.height:
            raise ValueError(
                f"数据量不足以按比例 {test_ratio} 划分 (总行数: {sorted_data.height})"
            )
        self.train_data = sorted_data[:split_index]
        self.test_data = sorted_data[split_index:]
        logger.info(f"训练集形状: {self.train_data.shape}, 测试集形状: {self.test_data.shape}")
        train_max_time = self.train_data["Time"].max()
        test_min_time = self.test_data["Time"].min()
        logger.info(f"训练集最大时间: {train_max_time}, 测试集最小时间: {test_min_time}")
        if train_max_time <= test_min_time:
            logger.info("时间划分正确,训练集时间早于测试集")
        else:
            logger.warning("时间划分存在问题,训练集时间晚于测试集")
        return self.train_data, self.test_data

    def prepare_features_labels(self, feature_cols: Optional[List[str]] = None, label_col: str = "Class") -> None:
        """Convert the train/test partitions into NumPy feature/label arrays.

        Args:
            feature_cols: Columns to use as features; defaults to every column
                except ``label_col``.
            label_col: Name of the target column.

        Raises:
            RuntimeError: If ``split_data_by_time`` has not been called.
        """
        logger.info("准备特征和标签...")
        train_data, test_data = self._require_split()
        if feature_cols is None:
            feature_cols = [col for col in train_data.columns if col != label_col]
        logger.info(f"使用的特征列: {feature_cols}")
        self.train_features = train_data.select(feature_cols).to_numpy()
        # flatten() turns the (n, 1) selection into a 1-D label vector (a copy).
        self.train_labels = train_data.select(label_col).to_numpy().flatten()
        self.test_features = test_data.select(feature_cols).to_numpy()
        self.test_labels = test_data.select(label_col).to_numpy().flatten()
        logger.info(f"训练特征形状: {self.train_features.shape}, 训练标签形状: {self.train_labels.shape}")
        logger.info(f"测试特征形状: {self.test_features.shape}, 测试标签形状: {self.test_labels.shape}")

    def get_statistics(self, label_col: str = "Class") -> Dict[str, Any]:
        """Summarize the dataset, loading it first if necessary.

        Args:
            label_col: Name of the target column (1 = fraud, 0 = normal).

        Returns:
            Dict with total rows, feature count, fraud / normal counts and the
            normal-to-fraud ratio (``inf`` when there are no fraud rows).
        """
        if self.data is None:
            self.load_data()
        data = self._require_data()
        fraud_count, normal_count = self._class_counts(data, label_col)
        # Avoid ZeroDivisionError on a dataset without any fraud rows.
        imbalance = normal_count / fraud_count if fraud_count else float("inf")
        return {
            "总记录数": data.height,
            "特征数": len([col for col in data.columns if col != label_col]),
            "欺诈交易数": fraud_count,
            "非欺诈交易数": normal_count,
            "不平衡比例": imbalance,
        }
def load_data(file_path: str = "data/creditcard.csv", test_ratio: float = 0.2) -> CreditCardDataProcessor:
    """Run the full preprocessing pipeline and return the populated processor.

    Steps: read the CSV, log validation info, split by time, then build the
    NumPy train/test feature and label arrays.

    Args:
        file_path: Path to the transactions CSV (a relative path is resolved
            against the current working directory).
        test_ratio: Fraction of the most recent rows held out as the test set;
            defaults to the previously hard-coded 0.2.

    Returns:
        The ``CreditCardDataProcessor`` with data, splits and arrays set.
    """
    processor = CreditCardDataProcessor(file_path)
    processor.load_data()
    processor.validate_data()
    processor.split_data_by_time(test_ratio=test_ratio)
    processor.prepare_features_labels()
    return processor
if __name__ == "__main__":
    # Script entry point: run the default pipeline and print summary statistics.
    summary = load_data().get_statistics()
    print("\n=== 数据集统计信息 ===")
    for name, val in summary.items():
        print(f"{name}: {val}")