113 lines
4.7 KiB
Python
113 lines
4.7 KiB
Python
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import polars as pl
|
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CreditCardDataProcessor:
    """Load, validate, and time-split the credit-card fraud CSV dataset.

    Expected workflow: ``load_data()`` -> ``validate_data()`` ->
    ``split_data_by_time()`` -> ``prepare_features_labels()``.  The split is
    chronological (sorted by the "Time" column) so the test set never leaks
    information from the future into training.
    """

    def __init__(self, file_path: str) -> None:
        """Remember the CSV path; no I/O happens until load_data().

        Args:
            file_path: Path to the credit-card transactions CSV.
        """
        self.file_path = file_path
        # Raw dataset and its chronological train/test split.
        self.data: Optional[pl.DataFrame] = None
        self.train_data: Optional[pl.DataFrame] = None
        self.test_data: Optional[pl.DataFrame] = None
        # NumPy feature/label arrays, filled by prepare_features_labels().
        self.train_features: Optional[np.ndarray] = None
        self.train_labels: Optional[np.ndarray] = None
        self.test_features: Optional[np.ndarray] = None
        self.test_labels: Optional[np.ndarray] = None

    def _require_data(self) -> pl.DataFrame:
        """Return the loaded dataset, failing fast if load_data() wasn't called.

        Raises:
            RuntimeError: if self.data has not been populated yet.  This
                replaces the opaque AttributeError the unguarded access
                would otherwise raise.
        """
        if self.data is None:
            raise RuntimeError("dataset not loaded; call load_data() first")
        return self.data

    def load_data(self) -> None:
        """Read the CSV into self.data and log the class balance.

        Raises:
            Exception: re-raises any polars read error after logging it.
        """
        logger.info(f"加载数据集: {self.file_path}")
        try:
            self.data = pl.read_csv(
                self.file_path,
                # "Time" could be inferred as int; force float seconds.
                schema_overrides={"Time": pl.Float64},
            )
        except Exception as e:
            logger.error(f"加载数据失败: {e}")
            raise
        # Success-path logging kept outside the try: only the read can fail
        # with a "load" error; other errors should surface as themselves.
        logger.info(f"数据集加载成功,形状: {self.data.shape}")
        fraud_count = self.data.filter(pl.col("Class") == 1).height
        normal_count = self.data.filter(pl.col("Class") == 0).height
        logger.info(f"欺诈交易数量: {fraud_count}, 非欺诈交易数量: {normal_count}")

    def validate_data(self) -> None:
        """Check for missing values and log the label distribution.

        Raises:
            RuntimeError: if load_data() has not been called yet.
        """
        data = self._require_data()
        logger.info("开始数据验证...")
        missing_values = data.null_count()
        # null_count() is per-column; sum horizontally for the grand total.
        total_missing = missing_values.sum_horizontal().item()
        if total_missing > 0:
            logger.warning(f"发现缺失值: {total_missing} 个")
        else:
            logger.info("无缺失值,数据完整性良好")

        class_dist = data.group_by("Class").agg(pl.len().alias("count")).to_dict()
        logger.info(f"标签分布: {class_dist}")

    def split_data_by_time(self, test_ratio: float = 0.2) -> Tuple[pl.DataFrame, pl.DataFrame]:
        """Chronologically split the dataset into train and test sets.

        Args:
            test_ratio: Fraction of the (time-sorted) rows reserved for the
                test set; must lie strictly between 0 and 1.

        Returns:
            The (train, test) DataFrames, also stored on the instance.

        Raises:
            RuntimeError: if load_data() has not been called yet.
            ValueError: if test_ratio is outside (0, 1).
        """
        data = self._require_data()
        if not 0.0 < test_ratio < 1.0:
            raise ValueError(f"test_ratio must be in (0, 1), got {test_ratio}")
        logger.info(f"按照时间顺序划分数据集,测试集比例: {test_ratio}")
        sorted_data = data.sort("Time")
        split_index = int(sorted_data.height * (1 - test_ratio))
        self.train_data = sorted_data[:split_index]
        self.test_data = sorted_data[split_index:]

        logger.info(f"训练集形状: {self.train_data.shape}, 测试集形状: {self.test_data.shape}")

        # Sanity check: training data must not come after test data in time.
        train_max_time = self.train_data["Time"].max()
        test_min_time = self.test_data["Time"].min()
        logger.info(f"训练集最大时间: {train_max_time}, 测试集最小时间: {test_min_time}")

        if train_max_time <= test_min_time:
            logger.info("时间划分正确,训练集时间早于测试集")
        else:
            logger.warning("时间划分存在问题,训练集时间晚于测试集")

        return self.train_data, self.test_data

    def prepare_features_labels(self, feature_cols: Optional[List[str]] = None, label_col: str = "Class") -> None:
        """Materialize train/test feature matrices and label vectors as NumPy.

        Args:
            feature_cols: Columns to use as features; defaults to every
                column except label_col.
            label_col: Name of the label column.

        Raises:
            RuntimeError: if the data has not been loaded and split yet.
        """
        data = self._require_data()
        if self.train_data is None or self.test_data is None:
            raise RuntimeError("data not split; call split_data_by_time() first")
        logger.info("准备特征和标签...")
        if feature_cols is None:
            feature_cols = [col for col in data.columns if col != label_col]

        logger.info(f"使用的特征列: {feature_cols}")

        self.train_features = self.train_data.select(feature_cols).to_numpy()
        self.train_labels = self.train_data.select(label_col).to_numpy().flatten()
        self.test_features = self.test_data.select(feature_cols).to_numpy()
        self.test_labels = self.test_data.select(label_col).to_numpy().flatten()

        logger.info(f"训练特征形状: {self.train_features.shape}, 训练标签形状: {self.train_labels.shape}")
        logger.info(f"测试特征形状: {self.test_features.shape}, 测试标签形状: {self.test_labels.shape}")

    def get_statistics(self) -> Dict[str, Any]:
        """Return summary statistics, lazily loading the data if needed.

        Returns:
            Mapping with record counts, feature count, fraud/non-fraud
            counts, and the class-imbalance ratio (normal / fraud).
        """
        if self.data is None:
            self.load_data()

        # Compute each class count once instead of re-scanning per stat.
        fraud = self.data.filter(pl.col("Class") == 1).height
        normal = self.data.filter(pl.col("Class") == 0).height
        return {
            "总记录数": self.data.height,
            "特征数": len([col for col in self.data.columns if col != "Class"]),
            "欺诈交易数": fraud,
            "非欺诈交易数": normal,
            # inf instead of ZeroDivisionError on a fraud-free dataset.
            "不平衡比例": normal / fraud if fraud else float("inf"),
        }
|
|
|
|
|
|
def load_data(file_path: str = "data/creditcard.csv", test_ratio: float = 0.2) -> CreditCardDataProcessor:
    """Run the full preparation pipeline in one call.

    Loads the CSV, validates it, performs the chronological train/test
    split, and materializes the NumPy feature/label arrays.

    Args:
        file_path: Path to the credit-card transactions CSV.
        test_ratio: Fraction of time-ordered rows reserved for the test
            set; forwarded to split_data_by_time() (was previously fixed
            at that method's 0.2 default).

    Returns:
        A fully prepared CreditCardDataProcessor instance.
    """
    processor = CreditCardDataProcessor(file_path)
    processor.load_data()
    processor.validate_data()
    processor.split_data_by_time(test_ratio)
    processor.prepare_features_labels()
    return processor
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the whole pipeline, then dump the dataset
    # summary to stdout, one "key: value" line per statistic.
    prepared = load_data()
    summary = prepared.get_statistics()
    print("\n=== 数据集统计信息 ===")
    for line in (f"{key}: {value}" for key, value in summary.items()):
        print(line)
|