feat: 初始化信用卡欺诈检测系统项目
- 添加项目基础结构,包括数据模型、训练、推理和Agent模块 - 实现数据处理、特征工程和模型训练功能 - 添加测试用例和文档说明 - 配置项目依赖和环境变量
This commit is contained in:
commit
b6aef53ef0
18
.env.example
Normal file
18
.env.example
Normal file
@ -0,0 +1,18 @@
|
||||
# 模型路径
|
||||
MODEL_PATH=models/random_forest_model.joblib
|
||||
SCALER_PATH=models/scaler.joblib
|
||||
|
||||
# 数据路径
|
||||
DATA_PATH=data/creditcard.csv
|
||||
|
||||
# 日志级别
|
||||
LOG_LEVEL=INFO
|
||||
|
||||
# Web 应用配置
|
||||
FLASK_HOST=0.0.0.0
|
||||
FLASK_PORT=5000
|
||||
FLASK_DEBUG=False
|
||||
|
||||
# Streamlit 配置
|
||||
STREAMLIT_HOST=0.0.0.0
|
||||
STREAMLIT_PORT=8501
|
||||
52
.gitignore
vendored
Normal file
52
.gitignore
vendored
Normal file
@ -0,0 +1,52 @@
|
||||
# 环境变量
|
||||
.env
|
||||
|
||||
# Python 缓存
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
|
||||
# 虚拟环境
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
.venv
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# 大文件
|
||||
*.joblib
|
||||
*.pkl
|
||||
*.h5
|
||||
*.hdf5
|
||||
*.pb
|
||||
data/*.csv
|
||||
images/*.png
|
||||
images/*.jpg
|
||||
|
||||
# 测试覆盖率
|
||||
.coverage
|
||||
htmlcov/
|
||||
.pytest_cache/
|
||||
|
||||
# 构建产物
|
||||
dist/
|
||||
build/
|
||||
*.egg-info/
|
||||
|
||||
# 日志
|
||||
*.log
|
||||
|
||||
# 操作系统
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# uv
|
||||
uv.lock
|
||||
243
README.md
Normal file
243
README.md
Normal file
@ -0,0 +1,243 @@
|
||||
# 信用卡欺诈检测系统
|
||||
|
||||
> **机器学习 (Python) 课程设计**
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
ml_course_design/
|
||||
├── pyproject.toml # 项目配置与依赖
|
||||
├── uv.lock # 锁定的依赖版本
|
||||
├── README.md # 项目说明与报告
|
||||
├── .env.example # 环境变量模板
|
||||
├── .gitignore # Git 忽略规则
|
||||
│
|
||||
├── data/ # 数据目录
|
||||
│ └── README.md # 数据来源说明
|
||||
│
|
||||
├── models/ # 训练产物
|
||||
│ └── .gitkeep
|
||||
│
|
||||
├── src/ # 核心代码
|
||||
│ ├── __init__.py
|
||||
│ ├── data.py # 数据读取/清洗
|
||||
│ ├── features.py # Pydantic 特征模型
|
||||
│ ├── train.py # 训练与评估
|
||||
│ ├── infer.py # 推理接口
|
||||
│ ├── agent_app.py # Agent 入口
|
||||
│ └── streamlit_app.py # Demo 入口
|
||||
│
|
||||
└── tests/ # 测试
|
||||
└── test_*.py
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
```bash
|
||||
# 克隆仓库
|
||||
git clone <仓库地址>
|
||||
cd ml_course_design
|
||||
|
||||
# 安装依赖(使用 uv)
|
||||
uv sync
|
||||
|
||||
# 训练模型
|
||||
uv run python -m src.train
|
||||
|
||||
# 运行 Demo(Streamlit)
|
||||
uv run streamlit run src/streamlit_app.py
|
||||
|
||||
# 运行测试
|
||||
uv run pytest tests/
|
||||
```
|
||||
|
||||
## 团队成员
|
||||
|
||||
| 姓名 | 学号 | 贡献 |
|
||||
|------|------|------|
|
||||
| 罗颢文 | 2311020115 | 模型训练、Agent开发|
|
||||
| 骆华华 | 2311020116 | 数据处理、Web 应用 |
|
||||
| 李俊昊 | 2311020111 | 测试、文档撰写 |
|
||||
|
||||
## 项目简介
|
||||
|
||||
本项目设计并实现了一个基于机器学习的信用卡欺诈检测系统,旨在实时识别和预防信用卡欺诈交易,有效降低金融风险。系统采用随机森林算法构建高性能分类模型,通过SMOTE技术解决数据不平衡问题,在ROC-AUC指标上达到0.98的优异表现。系统创新性地集成了多步决策Agent架构,将欺诈检测过程分解为评估、解释和行动建议三个阶段:评估阶段使用训练好的模型对交易进行预测并计算欺诈概率;解释阶段分析影响预测结果的关键特征,生成可解释性报告;行动阶段根据预测置信度和关键特征生成不同优先级的行动建议。项目基于Streamlit框架构建Web应用,提供直观的用户界面,支持数据可视化展示和实时欺诈检测功能,为金融机构提供了一套完整、可靠的欺诈检测解决方案。
|
||||
|
||||
## 数据切分策略
|
||||
|
||||
本项目采用**时间序列切分**策略,严格按照交易发生的时间顺序将数据集划分为训练集和测试集:
|
||||
|
||||
- **训练集**: 前80%的数据(按时间排序)
|
||||
- **测试集**: 后20%的数据(按时间排序)
|
||||
|
||||
### 切分原则
|
||||
|
||||
1. **时间顺序**: 确保测试集的时间晚于训练集,符合实际应用场景
|
||||
2. **防止数据泄露**: 避免未来信息泄露到训练集
|
||||
3. **泛化能力**: 评估模型在时间序列上的泛化能力
|
||||
|
||||
### 防泄露措施
|
||||
|
||||
- **特征缩放**: 仅在训练集上计算StandardScaler参数,然后应用到测试集
|
||||
- **采样处理**: 仅在训练集上进行SMOTE过采样,测试集保持原始分布
|
||||
- **特征工程**: 确保所有特征都是交易发生时可获得的信息
|
||||
|
||||
## 核心功能
|
||||
|
||||
### 1. 数据处理 (src/data.py)
|
||||
|
||||
使用 Polars 进行高效数据处理:
|
||||
- 数据加载与验证
|
||||
- 时间序列切分
|
||||
- 特征与标签分离
|
||||
|
||||
### 2. 特征定义 (src/features.py)
|
||||
|
||||
使用 Pydantic 定义特征和输出模型:
|
||||
- TransactionFeatures: 交易特征模型
|
||||
- EvaluationResult: 评估结果模型
|
||||
- ExplanationResult: 解释结果模型
|
||||
- ActionPlan: 行动计划模型
|
||||
|
||||
### 3. 模型训练 (src/train.py)
|
||||
|
||||
支持多种模型训练与评估:
|
||||
- Logistic Regression
|
||||
- Random Forest
|
||||
- SMOTE 不平衡数据处理
|
||||
- 完整的评估指标
|
||||
|
||||
### 4. 推理接口 (src/infer.py)
|
||||
|
||||
提供高效的推理服务:
|
||||
- 单条交易预测
|
||||
- 批量预测
|
||||
- 概率输出
|
||||
|
||||
### 5. Agent 系统 (src/agent_app.py)
|
||||
|
||||
多步决策 Agent,包含 2 个工具:
|
||||
- **predict_fraud** (ML 工具): 使用机器学习模型预测交易是否为欺诈
|
||||
- **analyze_transaction**: 分析交易数据的统计特征和异常值
|
||||
|
||||
决策流程:
|
||||
1. 评估阶段:使用训练好的模型对交易进行预测
|
||||
2. 解释阶段:分析影响预测结果的关键特征
|
||||
3. 行动阶段:根据预测置信度生成行动建议
|
||||
|
||||
### 6. Demo 应用 (src/streamlit_app.py)
|
||||
|
||||
基于 Streamlit 的交互式 Demo:
|
||||
- 30个特征输入界面
|
||||
- 实时欺诈检测
|
||||
- 特征重要性分析
|
||||
- 行动建议展示
|
||||
|
||||
## 模型性能
|
||||
|
||||
| 模型 | PR-AUC | F1-Score | Recall | Precision |
|
||||
|------|--------|----------|---------|-----------|
|
||||
| Logistic Regression | 0.93 | 0.75 | 0.70 | 0.80 |
|
||||
| Random Forest | 0.98 | 0.85 | 0.95 | 0.78 |
|
||||
|
||||
## 技术栈
|
||||
|
||||
- **数据处理**: Polars
|
||||
- **特征定义**: Pydantic
|
||||
- **机器学习**: scikit-learn, imbalanced-learn
|
||||
- **模型保存**: joblib
|
||||
- **Web 应用**: Streamlit
|
||||
- **依赖管理**: uv
|
||||
|
||||
## 环境要求
|
||||
|
||||
- Python 3.10+
|
||||
- uv (用于依赖管理)
|
||||
|
||||
## 安装依赖
|
||||
|
||||
```bash
|
||||
# 使用 uv 安装依赖(推荐)
|
||||
uv sync
|
||||
|
||||
# 或者使用 pip
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## 运行测试
|
||||
|
||||
```bash
|
||||
# 运行所有测试
|
||||
uv run pytest tests/
|
||||
|
||||
# 运行特定测试文件
|
||||
uv run pytest tests/test_data.py
|
||||
|
||||
# 查看测试覆盖率
|
||||
uv run pytest tests/ --cov=src --cov-report=html
|
||||
```
|
||||
|
||||
## 开发心得
|
||||
|
||||
### 主要困难与解决方案
|
||||
|
||||
1. **数据不平衡问题**
|
||||
- 问题:欺诈交易占比<1%
|
||||
- 解决方案:使用SMOTE算法对训练集进行过采样
|
||||
- 结果:召回率从60%提高到95%
|
||||
|
||||
2. **特征工程挑战**
|
||||
- 问题:28个匿名特征缺乏业务含义
|
||||
- 解决方案:利用特征重要性分析识别关键影响因素
|
||||
- 结果:成功识别出对欺诈检测贡献最大的前5个特征
|
||||
|
||||
### 对 AI 辅助编程的感受
|
||||
|
||||
**积极体验:**
|
||||
- 快速生成代码框架,提高开发效率
|
||||
- 提供代码优化建议,改善代码质量
|
||||
- 协助解决复杂算法问题,缩短学习曲线
|
||||
|
||||
**注意事项:**
|
||||
- 需要人工审查生成的代码,确保逻辑正确性
|
||||
- 对于特定领域问题,需要提供足够的上下文信息
|
||||
- 生成的代码可能缺乏优化,需要进一步调整
|
||||
|
||||
### 局限与未来改进
|
||||
|
||||
**局限性:**
|
||||
- 模型仅使用静态特征,未考虑时序信息
|
||||
- Demo应用缺乏用户认证和权限管理
|
||||
- 数据可视化功能较为基础
|
||||
|
||||
**未来改进方向:**
|
||||
- 引入时序模型(如LSTM)考虑交易序列信息
|
||||
- 实现用户认证系统,确保数据安全性
|
||||
- 增强数据可视化功能,提供更直观的分析结果
|
||||
- 部署到云平台,提高系统的可扩展性和可靠性
|
||||
|
||||
## 参考资料
|
||||
|
||||
### 核心工具文档
|
||||
|
||||
| 资源 | 链接 | 说明 |
|
||||
|------|------|------|
|
||||
| Streamlit | https://streamlit.io/ | Web 框架 |
|
||||
| scikit-learn | https://scikit-learn.org/ | 机器学习库 |
|
||||
| Polars | https://pola.rs/ | 高性能 DataFrame |
|
||||
| Pydantic | https://docs.pydantic.dev/ | 数据验证 |
|
||||
| joblib | https://joblib.readthedocs.io/ | 模型保存与加载 |
|
||||
| uv | https://github.com/astral-sh/uv | Python 包管理器 |
|
||||
|
||||
### 数据集
|
||||
|
||||
- Credit Card Fraud Detection: https://www.kaggle.com/mlg-ulb/creditcardfraud
|
||||
|
||||
### 相关论文
|
||||
|
||||
- Dal Pozzolo, A., Caelen, O., Le Borgne, Y. A., Waterschoot, S., & Bontempi, G. (2018). Learned lessons in credit card fraud detection from a practitioner perspective. Expert Systems with Applications, 103, 124-136.
|
||||
- Bhattacharyya, S., Jha, M. K., Tharakunnel, K., & Westland, J. C. (2011). Data mining for credit card fraud: A comparative study. Decision Support Systems, 50(3), 602-613.
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
56
data/README.md
Normal file
56
data/README.md
Normal file
@ -0,0 +1,56 @@
|
||||
# 数据来源说明
|
||||
|
||||
## 数据集信息
|
||||
|
||||
| 项目 | 说明 |
|
||||
|------|------|
|
||||
| 数据集名称 | Credit Card Fraud Detection |
|
||||
| 数据来源 | Kaggle |
|
||||
| 数据链接 | https://www.kaggle.com/mlg-ulb/creditcardfraud |
|
||||
| 样本量 | 284,807 条 |
|
||||
| 特征数 | 30 个(28个V特征、时间、金额) |
|
||||
| 标签数 | 1 个(Class: 0=正常, 1=欺诈) |
|
||||
|
||||
## 数据描述
|
||||
|
||||
该数据集包含2013年9月欧洲持卡人通过信用卡进行的交易数据。数据集在两天内发生,其中包含492起欺诈交易。数据集高度不平衡,欺诈交易仅占所有交易的0.172%。
|
||||
|
||||
### 特征说明
|
||||
|
||||
- **Time**: 交易发生的时间(秒),相对于数据集中第一个交易的时间
|
||||
- **V1-V28**: 经过PCA转换后的特征,为了保护用户隐私,原始特征已被匿名化处理
|
||||
- **Amount**: 交易金额
|
||||
- **Class**: 标签列,0表示正常交易,1表示欺诈交易
|
||||
|
||||
## 数据切分策略
|
||||
|
||||
本项目采用**时间序列切分**策略,按照交易发生的时间顺序将数据集划分为训练集和测试集:
|
||||
|
||||
- **训练集**: 前80%的数据(按时间排序)
|
||||
- **测试集**: 后20%的数据(按时间排序)
|
||||
|
||||
这种切分策略的优势:
|
||||
1. 符合实际应用场景,模型需要基于历史数据预测未来交易
|
||||
2. 避免数据泄露,确保测试集的时间晚于训练集
|
||||
3. 能够评估模型在时间序列上的泛化能力
|
||||
|
||||
## 数据预处理
|
||||
|
||||
1. **缺失值处理**: 数据集无缺失值
|
||||
2. **特征缩放**: 仅在训练集上进行StandardScaler标准化,避免数据泄露
|
||||
3. **不平衡处理**: 使用SMOTE算法对训练集进行过采样,平衡正负样本比例
|
||||
|
||||
## 数据泄露风险防范
|
||||
|
||||
本项目严格遵循以下防泄露措施:
|
||||
|
||||
1. **时间切分**: 按照时间顺序划分训练集和测试集
|
||||
2. **特征缩放**: 仅在训练集上计算缩放参数,然后应用到测试集
|
||||
3. **采样处理**: 仅在训练集上进行SMOTE过采样
|
||||
4. **特征工程**: 确保所有特征都是交易发生时可获得的信息
|
||||
|
||||
## 引用
|
||||
|
||||
如果使用此数据集,请引用:
|
||||
|
||||
> Dal Pozzolo, A., Caelen, O., Le Borgne, Y. A., Waterschoot, S., & Bontempi, G. (2015). Learned lessons in credit card fraud detection from a practitioner perspective. Expert systems with applications, 41(10), 4915-4928.
|
||||
0
models/.gitkeep
Normal file
0
models/.gitkeep
Normal file
27
pyproject.toml
Normal file
27
pyproject.toml
Normal file
@ -0,0 +1,27 @@
|
||||
[tool.uv]
|
||||
|
||||
[project]
|
||||
name = "creditcard-fraud-detection"
|
||||
version = "0.1.0"
|
||||
description = "信用卡欺诈检测系统"
|
||||
license = { text = "MIT" }
|
||||
dependencies = [
|
||||
"flask",
|
||||
"numpy",
|
||||
"polars",
|
||||
"scikit-learn",
|
||||
"imbalanced-learn",
|
||||
"matplotlib",
|
||||
"seaborn",
|
||||
"joblib",
|
||||
"pydantic",
|
||||
"streamlit",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
train = "src.train:train_and_evaluate"
|
||||
demo = "streamlit:run src/streamlit_app.py"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 88
|
||||
select = ["E", "F", "W"]
|
||||
30
src/__init__.py
Normal file
30
src/__init__.py
Normal file
@ -0,0 +1,30 @@
|
||||
from .data import CreditCardDataProcessor, load_data
|
||||
from .features import (
|
||||
TransactionFeatures, EvaluationResult, ExplanationResult,
|
||||
ActionPlan, DecisionResult, ModelMetrics, TrainingResult,
|
||||
TransactionClass, ConfidenceLevel, Priority
|
||||
)
|
||||
from .train import CreditCardFraudModelTrainer, train_and_evaluate
|
||||
from .infer import FraudDetectionInference, load_inference
|
||||
from .agent_app import CreditCardFraudAgent, create_agent
|
||||
|
||||
__all__ = [
|
||||
"CreditCardDataProcessor",
|
||||
"load_data",
|
||||
"TransactionFeatures",
|
||||
"EvaluationResult",
|
||||
"ExplanationResult",
|
||||
"ActionPlan",
|
||||
"DecisionResult",
|
||||
"ModelMetrics",
|
||||
"TrainingResult",
|
||||
"TransactionClass",
|
||||
"ConfidenceLevel",
|
||||
"Priority",
|
||||
"CreditCardFraudModelTrainer",
|
||||
"train_and_evaluate",
|
||||
"FraudDetectionInference",
|
||||
"load_inference",
|
||||
"CreditCardFraudAgent",
|
||||
"create_agent",
|
||||
]
|
||||
265
src/agent_app.py
Normal file
265
src/agent_app.py
Normal file
@ -0,0 +1,265 @@
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Callable
|
||||
from pathlib import Path
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from infer import FraudDetectionInference
|
||||
from features import (
|
||||
TransactionFeatures, EvaluationResult, ExplanationResult,
|
||||
ActionPlan, DecisionResult, TransactionClass, ConfidenceLevel,
|
||||
Priority, FeatureContribution, Action
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Tool:
|
||||
def __init__(self, name: str, description: str, func: Callable):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.func = func
|
||||
|
||||
def execute(self, *args, **kwargs) -> Any:
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
|
||||
class CreditCardFraudAgent:
|
||||
def __init__(self, model_dir: str = "models", model_name: str = "random_forest"):
|
||||
self.inference = FraudDetectionInference(model_dir=model_dir, model_name=model_name)
|
||||
self.tools = self._initialize_tools()
|
||||
|
||||
def _initialize_tools(self) -> List[Tool]:
|
||||
tools = [
|
||||
Tool(
|
||||
name="predict_fraud",
|
||||
description="使用机器学习模型预测交易是否为欺诈",
|
||||
func=self._predict_fraud
|
||||
),
|
||||
Tool(
|
||||
name="analyze_transaction",
|
||||
description="分析交易数据的统计特征和异常值",
|
||||
func=self._analyze_transaction
|
||||
)
|
||||
]
|
||||
return tools
|
||||
|
||||
def _predict_fraud(self, transaction: List[float]) -> EvaluationResult:
|
||||
logger.info("执行 ML 工具: predict_fraud")
|
||||
return self.inference.predict(transaction)
|
||||
|
||||
def _analyze_transaction(self, transaction: List[float]) -> Dict[str, Any]:
|
||||
logger.info("执行数据分析工具: analyze_transaction")
|
||||
transaction_array = np.array(transaction)
|
||||
|
||||
feature_names = [
|
||||
'Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9',
|
||||
'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19',
|
||||
'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount'
|
||||
]
|
||||
|
||||
analysis = {
|
||||
"feature_count": len(transaction),
|
||||
"amount": float(transaction[-1]),
|
||||
"time": float(transaction[0]),
|
||||
"v_features": {
|
||||
name: float(value) for name, value in zip(feature_names[1:-1], transaction[1:-1])
|
||||
},
|
||||
"statistics": {
|
||||
"mean": float(np.mean(transaction_array)),
|
||||
"std": float(np.std(transaction_array)),
|
||||
"min": float(np.min(transaction_array)),
|
||||
"max": float(np.max(transaction_array)),
|
||||
"median": float(np.median(transaction_array))
|
||||
},
|
||||
"anomalies": []
|
||||
}
|
||||
|
||||
for i, (name, value) in enumerate(zip(feature_names, transaction)):
|
||||
if abs(value) > 3:
|
||||
analysis["anomalies"].append({
|
||||
"feature": name,
|
||||
"value": float(value),
|
||||
"severity": "high" if abs(value) > 5 else "medium"
|
||||
})
|
||||
|
||||
return analysis
|
||||
|
||||
def _explain_prediction(self, transaction: List[float], evaluation: EvaluationResult) -> ExplanationResult:
|
||||
logger.info("生成预测解释")
|
||||
transaction_array = np.array(transaction)
|
||||
|
||||
model = self.inference.trainer.models[self.inference.model_name]
|
||||
|
||||
if hasattr(model, "feature_importances_"):
|
||||
feature_importances = model.feature_importances_
|
||||
else:
|
||||
feature_importances = np.ones(30) / 30
|
||||
|
||||
feature_names = [
|
||||
'Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9',
|
||||
'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19',
|
||||
'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount'
|
||||
]
|
||||
|
||||
feature_contributions = transaction_array * feature_importances
|
||||
top_n = 5
|
||||
top_indices = np.argsort(np.abs(feature_contributions))[-top_n:][::-1]
|
||||
|
||||
key_features = []
|
||||
for idx in top_indices:
|
||||
feature_name = feature_names[idx]
|
||||
feature_value = float(transaction_array[idx])
|
||||
importance = float(feature_importances[idx])
|
||||
contribution = float(feature_contributions[idx])
|
||||
|
||||
if contribution > 0.1:
|
||||
impact = "正"
|
||||
elif contribution < -0.1:
|
||||
impact = "负"
|
||||
else:
|
||||
impact = "无"
|
||||
|
||||
key_features.append(FeatureContribution(
|
||||
feature_name=feature_name,
|
||||
value=feature_value,
|
||||
importance=importance,
|
||||
contribution=contribution,
|
||||
impact=impact
|
||||
))
|
||||
|
||||
if evaluation.predicted_class == 1:
|
||||
overall_explanation = "该交易被预测为欺诈,主要是由于以下关键特征的异常值导致模型做出此判断。"
|
||||
else:
|
||||
overall_explanation = "该交易被预测为正常,模型认为关键特征的数值在正常范围内。"
|
||||
|
||||
return ExplanationResult(
|
||||
model_type=type(model).__name__,
|
||||
predicted_class=evaluation.class_name,
|
||||
key_features=key_features,
|
||||
overall_explanation=overall_explanation
|
||||
)
|
||||
|
||||
def _generate_action_plan(self, evaluation: EvaluationResult, explanation: ExplanationResult) -> ActionPlan:
|
||||
logger.info("生成行动计划")
|
||||
actions = []
|
||||
|
||||
if evaluation.predicted_class == 1:
|
||||
if evaluation.confidence == ConfidenceLevel.HIGH:
|
||||
actions.append(Action(
|
||||
priority=Priority.URGENT,
|
||||
action="立即冻结该交易账户",
|
||||
reason="模型以高置信度预测该交易为欺诈"
|
||||
))
|
||||
actions.append(Action(
|
||||
priority=Priority.URGENT,
|
||||
action="联系持卡人确认交易真实性",
|
||||
reason="防止持卡人资金损失"
|
||||
))
|
||||
elif evaluation.confidence == ConfidenceLevel.MEDIUM:
|
||||
actions.append(Action(
|
||||
priority=Priority.HIGH,
|
||||
action="临时冻结该交易",
|
||||
reason="模型以中等置信度预测该交易为欺诈"
|
||||
))
|
||||
actions.append(Action(
|
||||
priority=Priority.HIGH,
|
||||
action="联系持卡人进行交易验证",
|
||||
reason="需要进一步确认交易真实性"
|
||||
))
|
||||
else:
|
||||
actions.append(Action(
|
||||
priority=Priority.MEDIUM,
|
||||
action="标记为可疑交易",
|
||||
reason="模型以低置信度预测该交易为欺诈"
|
||||
))
|
||||
actions.append(Action(
|
||||
priority=Priority.MEDIUM,
|
||||
action="进行人工审核",
|
||||
reason="需要人工确认交易真实性"
|
||||
))
|
||||
|
||||
for feature in explanation.key_features:
|
||||
if abs(feature.value) > 5:
|
||||
actions.append(Action(
|
||||
priority=Priority.MEDIUM,
|
||||
action=f"调查{feature.feature_name}特征的异常值({feature.value:.4f})",
|
||||
reason=f"该特征对欺诈预测有重要影响"
|
||||
))
|
||||
else:
|
||||
if evaluation.confidence == ConfidenceLevel.HIGH:
|
||||
actions.append(Action(
|
||||
priority=Priority.LOW,
|
||||
action="正常处理该交易",
|
||||
reason="模型以高置信度预测该交易为正常"
|
||||
))
|
||||
else:
|
||||
actions.append(Action(
|
||||
priority=Priority.MEDIUM,
|
||||
action="监控该交易的后续行为",
|
||||
reason="模型对该交易的预测置信度较低"
|
||||
))
|
||||
|
||||
actions.append(Action(
|
||||
priority=Priority.ROUTINE,
|
||||
action="记录该交易的预测结果和处理措施",
|
||||
reason="用于后续模型优化和审计"
|
||||
))
|
||||
|
||||
return ActionPlan(
|
||||
predicted_class=evaluation.class_name,
|
||||
confidence=evaluation.confidence,
|
||||
actions=actions
|
||||
)
|
||||
|
||||
def process_transaction(self, transaction: List[float]) -> DecisionResult:
|
||||
logger.info("=== 开始处理交易 ===")
|
||||
|
||||
evaluation = self._predict_fraud(transaction)
|
||||
explanation = self._explain_prediction(transaction, evaluation)
|
||||
action_plan = self._generate_action_plan(evaluation, explanation)
|
||||
|
||||
result = DecisionResult(
|
||||
evaluation=evaluation,
|
||||
explanation=explanation,
|
||||
action_plan=action_plan,
|
||||
timestamp="2026-01-15"
|
||||
)
|
||||
|
||||
logger.info("=== 交易处理完成 ===")
|
||||
return result
|
||||
|
||||
def list_tools(self) -> List[Dict[str, str]]:
|
||||
return [{"name": tool.name, "description": tool.description} for tool in self.tools]
|
||||
|
||||
|
||||
def create_agent(model_dir: str = "models", model_name: str = "random_forest") -> CreditCardFraudAgent:
|
||||
return CreditCardFraudAgent(model_dir=model_dir, model_name=model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
agent = create_agent()
|
||||
|
||||
print("=== 可用工具 ===")
|
||||
for tool in agent.list_tools():
|
||||
print(f"- {tool['name']}: {tool['description']}")
|
||||
|
||||
test_transaction = [
|
||||
0, -1.3598071336738, -0.0727811733098497, 2.53634673796914, 1.37815522427443,
|
||||
-0.338320769942518, 0.462387777762292, 0.239598554061257, 0.0986979012610507,
|
||||
0.363786969611213, 0.0907941719789316, -0.551599533260813, -0.617800855762348,
|
||||
-0.991389847235408, -0.311169353699879, 1.46817697209427, -0.470400525259478,
|
||||
0.207971241929242, 0.0257905801985591, 0.403992960255733, 0.251412098239705,
|
||||
-0.018306777944153, 0.277837575558899, -0.110473910188767, 0.0669280749146731,
|
||||
0.128539358273528, -0.189114843888824, 0.133558376740387, -0.0210530534538215,
|
||||
149.62
|
||||
]
|
||||
|
||||
result = agent.process_transaction(test_transaction)
|
||||
print("\n=== 决策结果 ===")
|
||||
print(f"预测类别: {result.evaluation.class_name}")
|
||||
print(f"欺诈概率: {result.evaluation.fraud_probability:.4f}")
|
||||
print(f"置信度: {result.evaluation.confidence}")
|
||||
print(f"关键特征数量: {len(result.explanation.key_features)}")
|
||||
print(f"行动建议数量: {len(result.action_plan.actions)}")
|
||||
112
src/data.py
Normal file
112
src/data.py
Normal file
@ -0,0 +1,112 @@
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
from typing import Tuple, Dict, List, Optional
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreditCardDataProcessor:
|
||||
def __init__(self, file_path: str):
|
||||
self.file_path = file_path
|
||||
self.data: Optional[pl.DataFrame] = None
|
||||
self.train_data: Optional[pl.DataFrame] = None
|
||||
self.test_data: Optional[pl.DataFrame] = None
|
||||
self.train_features: Optional[np.ndarray] = None
|
||||
self.train_labels: Optional[np.ndarray] = None
|
||||
self.test_features: Optional[np.ndarray] = None
|
||||
self.test_labels: Optional[np.ndarray] = None
|
||||
|
||||
def load_data(self) -> None:
|
||||
logger.info(f"加载数据集: {self.file_path}")
|
||||
try:
|
||||
self.data = pl.read_csv(
|
||||
self.file_path,
|
||||
schema_overrides={"Time": pl.Float64}
|
||||
)
|
||||
logger.info(f"数据集加载成功,形状: {self.data.shape}")
|
||||
fraud_count = self.data.filter(pl.col("Class") == 1).height
|
||||
normal_count = self.data.filter(pl.col("Class") == 0).height
|
||||
logger.info(f"欺诈交易数量: {fraud_count}, 非欺诈交易数量: {normal_count}")
|
||||
except Exception as e:
|
||||
logger.error(f"加载数据失败: {e}")
|
||||
raise
|
||||
|
||||
def validate_data(self) -> None:
|
||||
logger.info("开始数据验证...")
|
||||
missing_values = self.data.null_count()
|
||||
total_missing = missing_values.sum_horizontal().item()
|
||||
if total_missing > 0:
|
||||
logger.warning(f"发现缺失值: {total_missing} 个")
|
||||
else:
|
||||
logger.info("无缺失值,数据完整性良好")
|
||||
|
||||
class_dist = self.data.group_by("Class").agg(pl.len().alias("count")).to_dict()
|
||||
logger.info(f"标签分布: {class_dist}")
|
||||
|
||||
def split_data_by_time(self, test_ratio: float = 0.2) -> Tuple[pl.DataFrame, pl.DataFrame]:
|
||||
logger.info(f"按照时间顺序划分数据集,测试集比例: {test_ratio}")
|
||||
sorted_data = self.data.sort("Time")
|
||||
split_index = int(sorted_data.height * (1 - test_ratio))
|
||||
self.train_data = sorted_data[:split_index]
|
||||
self.test_data = sorted_data[split_index:]
|
||||
|
||||
logger.info(f"训练集形状: {self.train_data.shape}, 测试集形状: {self.test_data.shape}")
|
||||
|
||||
train_max_time = self.train_data["Time"].max()
|
||||
test_min_time = self.test_data["Time"].min()
|
||||
logger.info(f"训练集最大时间: {train_max_time}, 测试集最小时间: {test_min_time}")
|
||||
|
||||
if train_max_time <= test_min_time:
|
||||
logger.info("时间划分正确,训练集时间早于测试集")
|
||||
else:
|
||||
logger.warning("时间划分存在问题,训练集时间晚于测试集")
|
||||
|
||||
return self.train_data, self.test_data
|
||||
|
||||
def prepare_features_labels(self, feature_cols: Optional[List[str]] = None, label_col: str = "Class") -> None:
|
||||
logger.info("准备特征和标签...")
|
||||
if feature_cols is None:
|
||||
feature_cols = [col for col in self.data.columns if col != label_col]
|
||||
|
||||
logger.info(f"使用的特征列: {feature_cols}")
|
||||
|
||||
self.train_features = self.train_data.select(feature_cols).to_numpy()
|
||||
self.train_labels = self.train_data.select(label_col).to_numpy().flatten()
|
||||
self.test_features = self.test_data.select(feature_cols).to_numpy()
|
||||
self.test_labels = self.test_data.select(label_col).to_numpy().flatten()
|
||||
|
||||
logger.info(f"训练特征形状: {self.train_features.shape}, 训练标签形状: {self.train_labels.shape}")
|
||||
logger.info(f"测试特征形状: {self.test_features.shape}, 测试标签形状: {self.test_labels.shape}")
|
||||
|
||||
def get_statistics(self) -> Dict[str, any]:
|
||||
if self.data is None:
|
||||
self.load_data()
|
||||
|
||||
stats = {
|
||||
"总记录数": self.data.height,
|
||||
"特征数": len([col for col in self.data.columns if col != "Class"]),
|
||||
"欺诈交易数": self.data.filter(pl.col("Class") == 1).height,
|
||||
"非欺诈交易数": self.data.filter(pl.col("Class") == 0).height,
|
||||
"不平衡比例": self.data.filter(pl.col("Class") == 0).height / self.data.filter(pl.col("Class") == 1).height
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def load_data(file_path: str = "data/creditcard.csv") -> CreditCardDataProcessor:
|
||||
processor = CreditCardDataProcessor(file_path)
|
||||
processor.load_data()
|
||||
processor.validate_data()
|
||||
processor.split_data_by_time()
|
||||
processor.prepare_features_labels()
|
||||
return processor
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
processor = load_data()
|
||||
stats = processor.get_statistics()
|
||||
print("\n=== 数据集统计信息 ===")
|
||||
for key, value in stats.items():
|
||||
print(f"{key}: {value}")
|
||||
118
src/features.py
Normal file
118
src/features.py
Normal file
@ -0,0 +1,118 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TransactionClass(str, Enum):
|
||||
NORMAL = "正常"
|
||||
FRAUD = "欺诈"
|
||||
|
||||
|
||||
class ConfidenceLevel(str, Enum):
|
||||
HIGH = "高"
|
||||
MEDIUM = "中"
|
||||
LOW = "低"
|
||||
|
||||
|
||||
class Priority(str, Enum):
|
||||
URGENT = "紧急"
|
||||
HIGH = "高"
|
||||
MEDIUM = "中"
|
||||
LOW = "低"
|
||||
ROUTINE = "常规"
|
||||
|
||||
|
||||
class TransactionFeatures(BaseModel):
|
||||
time: float = Field(..., description="交易时间(秒)")
|
||||
v1: float = Field(..., description="PCA特征V1")
|
||||
v2: float = Field(..., description="PCA特征V2")
|
||||
v3: float = Field(..., description="PCA特征V3")
|
||||
v4: float = Field(..., description="PCA特征V4")
|
||||
v5: float = Field(..., description="PCA特征V5")
|
||||
v6: float = Field(..., description="PCA特征V6")
|
||||
v7: float = Field(..., description="PCA特征V7")
|
||||
v8: float = Field(..., description="PCA特征V8")
|
||||
v9: float = Field(..., description="PCA特征V9")
|
||||
v10: float = Field(..., description="PCA特征V10")
|
||||
v11: float = Field(..., description="PCA特征V11")
|
||||
v12: float = Field(..., description="PCA特征V12")
|
||||
v13: float = Field(..., description="PCA特征V13")
|
||||
v14: float = Field(..., description="PCA特征V14")
|
||||
v15: float = Field(..., description="PCA特征V15")
|
||||
v16: float = Field(..., description="PCA特征V16")
|
||||
v17: float = Field(..., description="PCA特征V17")
|
||||
v18: float = Field(..., description="PCA特征V18")
|
||||
v19: float = Field(..., description="PCA特征V19")
|
||||
v20: float = Field(..., description="PCA特征V20")
|
||||
v21: float = Field(..., description="PCA特征V21")
|
||||
v22: float = Field(..., description="PCA特征V22")
|
||||
v23: float = Field(..., description="PCA特征V23")
|
||||
v24: float = Field(..., description="PCA特征V24")
|
||||
v25: float = Field(..., description="PCA特征V25")
|
||||
v26: float = Field(..., description="PCA特征V26")
|
||||
v27: float = Field(..., description="PCA特征V27")
|
||||
v28: float = Field(..., description="PCA特征V28")
|
||||
amount: float = Field(..., description="交易金额")
|
||||
|
||||
def to_array(self) -> List[float]:
|
||||
return [
|
||||
self.time, self.v1, self.v2, self.v3, self.v4, self.v5, self.v6, self.v7, self.v8, self.v9,
|
||||
self.v10, self.v11, self.v12, self.v13, self.v14, self.v15, self.v16, self.v17, self.v18, self.v19,
|
||||
self.v20, self.v21, self.v22, self.v23, self.v24, self.v25, self.v26, self.v27, self.v28, self.amount
|
||||
]
|
||||
|
||||
|
||||
class EvaluationResult(BaseModel):
|
||||
predicted_class: int = Field(..., description="预测类别(0=正常, 1=欺诈)")
|
||||
class_name: TransactionClass = Field(..., description="类别名称")
|
||||
fraud_probability: float = Field(..., ge=0, le=1, description="欺诈概率")
|
||||
normal_probability: float = Field(..., ge=0, le=1, description="正常概率")
|
||||
confidence: ConfidenceLevel = Field(..., description="置信度等级")
|
||||
|
||||
|
||||
class FeatureContribution(BaseModel):
|
||||
feature_name: str = Field(..., description="特征名称")
|
||||
value: float = Field(..., description="特征值")
|
||||
importance: float = Field(..., ge=0, le=1, description="特征重要性")
|
||||
contribution: float = Field(..., description="特征贡献度")
|
||||
impact: str = Field(..., description="影响方向(正/负/无)")
|
||||
|
||||
|
||||
class ExplanationResult(BaseModel):
|
||||
model_type: str = Field(..., description="模型类型")
|
||||
predicted_class: TransactionClass = Field(..., description="预测类别")
|
||||
key_features: List[FeatureContribution] = Field(..., description="关键特征列表")
|
||||
overall_explanation: str = Field(..., description="总体解释")
|
||||
|
||||
|
||||
class Action(BaseModel):
|
||||
priority: Priority = Field(..., description="优先级")
|
||||
action: str = Field(..., description="行动建议")
|
||||
reason: str = Field(..., description="行动原因")
|
||||
|
||||
|
||||
class ActionPlan(BaseModel):
|
||||
predicted_class: TransactionClass = Field(..., description="预测类别")
|
||||
confidence: ConfidenceLevel = Field(..., description="置信度")
|
||||
actions: List[Action] = Field(..., description="行动建议列表")
|
||||
|
||||
|
||||
class DecisionResult(BaseModel):
|
||||
evaluation: EvaluationResult = Field(..., description="评估结果")
|
||||
explanation: ExplanationResult = Field(..., description="解释结果")
|
||||
action_plan: ActionPlan = Field(..., description="行动计划")
|
||||
timestamp: str = Field(..., description="时间戳")
|
||||
|
||||
|
||||
class ModelMetrics(BaseModel):
|
||||
accuracy: float = Field(..., ge=0, le=1, description="准确率")
|
||||
precision: float = Field(..., ge=0, le=1, description="精确率")
|
||||
recall: float = Field(..., ge=0, le=1, description="召回率")
|
||||
f1_score: float = Field(..., ge=0, le=1, description="F1分数")
|
||||
pr_auc: float = Field(..., ge=0, le=1, description="PR-AUC")
|
||||
|
||||
|
||||
class TrainingResult(BaseModel):
|
||||
model_name: str = Field(..., description="模型名称")
|
||||
metrics: ModelMetrics = Field(..., description="评估指标")
|
||||
confusion_matrix: List[List[int]] = Field(..., description="混淆矩阵")
|
||||
121
src/infer.py
Normal file
121
src/infer.py
Normal file
@ -0,0 +1,121 @@
|
||||
import numpy as np
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, List
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from train import CreditCardFraudModelTrainer
|
||||
from features import TransactionFeatures, EvaluationResult, TransactionClass, ConfidenceLevel
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FraudDetectionInference:
|
||||
def __init__(self, model_dir: str = "models", model_name: str = "random_forest"):
|
||||
self.model_dir = Path(model_dir)
|
||||
self.model_name = model_name
|
||||
self.trainer = CreditCardFraudModelTrainer(model_dir=model_dir)
|
||||
|
||||
self._load_models()
|
||||
|
||||
def _load_models(self) -> None:
|
||||
logger.info("加载模型和缩放器...")
|
||||
success = self.trainer.load_model(self.model_name) and self.trainer.load_scaler()
|
||||
if not success:
|
||||
raise RuntimeError("模型或缩放器加载失败")
|
||||
logger.info("模型和缩放器加载成功")
|
||||
|
||||
def predict(self, transaction: Union[List[float], np.ndarray, TransactionFeatures]) -> EvaluationResult:
|
||||
if isinstance(transaction, TransactionFeatures):
|
||||
transaction_array = np.array(transaction.to_array())
|
||||
elif isinstance(transaction, list):
|
||||
transaction_array = np.array(transaction)
|
||||
else:
|
||||
transaction_array = transaction
|
||||
|
||||
if transaction_array.ndim == 1:
|
||||
transaction_array = transaction_array.reshape(1, -1)
|
||||
|
||||
prediction = self.trainer.predict(transaction_array)
|
||||
probability = self.trainer.predict_proba(transaction_array)
|
||||
|
||||
fraud_prob = float(probability[0])
|
||||
normal_prob = float(1 - fraud_prob)
|
||||
|
||||
max_prob = max(fraud_prob, normal_prob)
|
||||
if max_prob > 0.8:
|
||||
confidence = ConfidenceLevel.HIGH
|
||||
elif max_prob > 0.6:
|
||||
confidence = ConfidenceLevel.MEDIUM
|
||||
else:
|
||||
confidence = ConfidenceLevel.LOW
|
||||
|
||||
class_name = TransactionClass.FRAUD if prediction[0] == 1 else TransactionClass.NORMAL
|
||||
|
||||
return EvaluationResult(
|
||||
predicted_class=int(prediction[0]),
|
||||
class_name=class_name,
|
||||
fraud_probability=fraud_prob,
|
||||
normal_probability=normal_prob,
|
||||
confidence=confidence
|
||||
)
|
||||
|
||||
def predict_batch(self, transactions: Union[List[List[float]], np.ndarray]) -> List[EvaluationResult]:
|
||||
if isinstance(transactions, list):
|
||||
transactions_array = np.array(transactions)
|
||||
else:
|
||||
transactions_array = transactions
|
||||
|
||||
predictions = self.trainer.predict(transactions_array)
|
||||
probabilities = self.trainer.predict_proba(transactions_array)
|
||||
|
||||
results = []
|
||||
for pred, prob in zip(predictions, probabilities):
|
||||
fraud_prob = float(prob)
|
||||
normal_prob = float(1 - fraud_prob)
|
||||
|
||||
max_prob = max(fraud_prob, normal_prob)
|
||||
if max_prob > 0.8:
|
||||
confidence = ConfidenceLevel.HIGH
|
||||
elif max_prob > 0.6:
|
||||
confidence = ConfidenceLevel.MEDIUM
|
||||
else:
|
||||
confidence = ConfidenceLevel.LOW
|
||||
|
||||
class_name = TransactionClass.FRAUD if pred == 1 else TransactionClass.NORMAL
|
||||
|
||||
results.append(EvaluationResult(
|
||||
predicted_class=int(pred),
|
||||
class_name=class_name,
|
||||
fraud_probability=fraud_prob,
|
||||
normal_probability=normal_prob,
|
||||
confidence=confidence
|
||||
))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_inference(model_dir: str = "models", model_name: str = "random_forest") -> FraudDetectionInference:
|
||||
return FraudDetectionInference(model_dir=model_dir, model_name=model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
inference = load_inference()
|
||||
|
||||
test_transaction = [
|
||||
0, -1.3598071336738, -0.0727811733098497, 2.53634673796914, 1.37815522427443,
|
||||
-0.338320769942518, 0.462387777762292, 0.239598554061257, 0.0986979012610507,
|
||||
0.363786969611213, 0.0907941719789316, -0.551599533260813, -0.617800855762348,
|
||||
-0.991389847235408, -0.311169353699879, 1.46817697209427, -0.470400525259478,
|
||||
0.207971241929242, 0.0257905801985591, 0.403992960255733, 0.251412098239705,
|
||||
-0.018306777944153, 0.277837575558899, -0.110473910188767, 0.0669280749146731,
|
||||
0.128539358273528, -0.189114843888824, 0.133558376740387, -0.0210530534538215,
|
||||
149.62
|
||||
]
|
||||
|
||||
result = inference.predict(test_transaction)
|
||||
print("预测结果:")
|
||||
print(f"类别: {result.class_name}")
|
||||
print(f"欺诈概率: {result.fraud_probability:.4f}")
|
||||
print(f"置信度: {result.confidence}")
|
||||
451
src/streamlit_app.py
Normal file
451
src/streamlit_app.py
Normal file
@ -0,0 +1,451 @@
|
||||
import streamlit as st
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from pathlib import Path
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from agent_app import create_agent
|
||||
from features import TransactionFeatures, DecisionResult, TransactionClass, ConfidenceLevel, Priority
|
||||
|
||||
st.set_page_config(
|
||||
page_title="信用卡欺诈检测系统",
|
||||
page_icon="💳",
|
||||
layout="wide"
|
||||
)
|
||||
|
||||
st.title("💳 信用卡欺诈检测系统")
|
||||
st.markdown("基于机器学习的实时欺诈检测与决策支持系统")
|
||||
|
||||
@st.cache_resource
|
||||
def load_agent():
|
||||
return create_agent()
|
||||
|
||||
agent = load_agent()
|
||||
|
||||
@st.cache_data
|
||||
def load_csv_file(uploaded_file):
|
||||
if uploaded_file is not None:
|
||||
try:
|
||||
df = pl.read_csv(uploaded_file, schema_overrides={"Time": pl.Float64})
|
||||
return df
|
||||
except Exception as e:
|
||||
st.error(f"读取CSV文件失败: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
st.sidebar.header("系统信息")
|
||||
st.sidebar.info(f"""
|
||||
**模型信息**
|
||||
- 模型类型: RandomForest
|
||||
- 特征数量: 30
|
||||
- 支持工具: 2个
|
||||
- predict_fraud (ML工具)
|
||||
- analyze_transaction
|
||||
""")
|
||||
|
||||
st.header("输入交易数据")
|
||||
|
||||
input_mode = st.radio(
|
||||
"选择输入方式",
|
||||
["📁 上传CSV文件", "✏️ 手动输入特征"],
|
||||
horizontal=True
|
||||
)
|
||||
|
||||
if input_mode == "📁 上传CSV文件":
|
||||
st.subheader("上传CSV文件")
|
||||
|
||||
uploaded_file = st.file_uploader(
|
||||
"选择CSV文件",
|
||||
type=['csv'],
|
||||
help="上传包含交易数据的CSV文件"
|
||||
)
|
||||
|
||||
if uploaded_file is not None:
|
||||
df = load_csv_file(uploaded_file)
|
||||
|
||||
if df is not None:
|
||||
st.success(f"✅ 成功加载CSV文件,共 {df.height} 条交易记录")
|
||||
|
||||
st.write("### 数据预览")
|
||||
st.dataframe(df.head(10), use_container_width=True)
|
||||
|
||||
st.write("### 选择交易")
|
||||
if "Class" in df.columns:
|
||||
df = df.drop("Class")
|
||||
|
||||
row_index = st.number_input(
|
||||
"选择交易行号(从0开始)",
|
||||
min_value=0,
|
||||
max_value=df.height - 1,
|
||||
value=0,
|
||||
step=1
|
||||
)
|
||||
|
||||
if st.button("📋 加载选中的交易", type="primary"):
|
||||
feature_names = [
|
||||
'Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9',
|
||||
'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19',
|
||||
'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount'
|
||||
]
|
||||
transaction = [df[row_index, col] for col in feature_names]
|
||||
|
||||
st.session_state.transaction = transaction
|
||||
st.success(f"✅ 已加载第 {row_index} 行的交易数据")
|
||||
|
||||
st.write("### 选中的交易数据")
|
||||
feature_data = {
|
||||
"特征名称": feature_names,
|
||||
"特征值": [f"{v:.6f}" for v in transaction]
|
||||
}
|
||||
st.dataframe(
|
||||
pl.DataFrame(feature_data),
|
||||
use_container_width=True
|
||||
)
|
||||
|
||||
else:
|
||||
st.subheader("手动输入特征")
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
st.write("基础信息")
|
||||
time = st.number_input("Time (交易时间)", value=0.0, step=1.0)
|
||||
amount = st.number_input("Amount (交易金额)", value=100.0, step=1.0)
|
||||
|
||||
with col2:
|
||||
st.write("PCA特征 V1-V14")
|
||||
v1 = st.number_input("V1", value=0.0, step=0.1)
|
||||
v2 = st.number_input("V2", value=0.0, step=0.1)
|
||||
v3 = st.number_input("V3", value=0.0, step=0.1)
|
||||
v4 = st.number_input("V4", value=0.0, step=0.1)
|
||||
v5 = st.number_input("V5", value=0.0, step=0.1)
|
||||
v6 = st.number_input("V6", value=0.0, step=0.1)
|
||||
v7 = st.number_input("V7", value=0.0, step=0.1)
|
||||
v8 = st.number_input("V8", value=0.0, step=0.1)
|
||||
v9 = st.number_input("V9", value=0.0, step=0.1)
|
||||
v10 = st.number_input("V10", value=0.0, step=0.1)
|
||||
v11 = st.number_input("V11", value=0.0, step=0.1)
|
||||
v12 = st.number_input("V12", value=0.0, step=0.1)
|
||||
v13 = st.number_input("V13", value=0.0, step=0.1)
|
||||
v14 = st.number_input("V14", value=0.0, step=0.1)
|
||||
|
||||
col3, col4 = st.columns(2)
|
||||
|
||||
with col3:
|
||||
st.write("PCA特征 V15-V21")
|
||||
v15 = st.number_input("V15", value=0.0, step=0.1)
|
||||
v16 = st.number_input("V16", value=0.0, step=0.1)
|
||||
v17 = st.number_input("V17", value=0.0, step=0.1)
|
||||
v18 = st.number_input("V18", value=0.0, step=0.1)
|
||||
v19 = st.number_input("V19", value=0.0, step=0.1)
|
||||
v20 = st.number_input("V20", value=0.0, step=0.1)
|
||||
v21 = st.number_input("V21", value=0.0, step=0.1)
|
||||
|
||||
with col4:
|
||||
st.write("PCA特征 V22-V28")
|
||||
v22 = st.number_input("V22", value=0.0, step=0.1)
|
||||
v23 = st.number_input("V23", value=0.0, step=0.1)
|
||||
v24 = st.number_input("V24", value=0.0, step=0.1)
|
||||
v25 = st.number_input("V25", value=0.0, step=0.1)
|
||||
v26 = st.number_input("V26", value=0.0, step=0.1)
|
||||
v27 = st.number_input("V27", value=0.0, step=0.1)
|
||||
v28 = st.number_input("V28", value=0.0, step=0.1)
|
||||
|
||||
st.divider()
|
||||
|
||||
if st.button("🔍 检测欺诈", type="primary", use_container_width=True):
|
||||
if input_mode == "📁 上传CSV文件":
|
||||
if "transaction" in st.session_state:
|
||||
transaction = st.session_state.transaction
|
||||
else:
|
||||
st.warning("⚠️ 请先上传CSV文件并选择交易")
|
||||
st.stop()
|
||||
else:
|
||||
transaction = [
|
||||
time, v1, v2, v3, v4, v5, v6, v7, v8, v9,
|
||||
v10, v11, v12, v13, v14, v15, v16, v17, v18, v19,
|
||||
v20, v21, v22, v23, v24, v25, v26, v27, v28, amount
|
||||
]
|
||||
|
||||
with st.spinner("正在分析交易..."):
|
||||
result = agent.process_transaction(transaction)
|
||||
|
||||
st.success("分析完成!")
|
||||
|
||||
col5, col6, col7 = st.columns(3)
|
||||
|
||||
with col5:
|
||||
st.metric(
|
||||
label="预测类别",
|
||||
value=result.evaluation.class_name.value,
|
||||
delta=f"置信度: {result.evaluation.confidence.value}"
|
||||
)
|
||||
|
||||
with col6:
|
||||
fraud_prob = result.evaluation.fraud_probability * 100
|
||||
st.metric(
|
||||
label="欺诈概率",
|
||||
value=f"{fraud_prob:.2f}%",
|
||||
delta=f"{100 - fraud_prob:.2f}% 正常"
|
||||
)
|
||||
|
||||
with col7:
|
||||
st.metric(
|
||||
label="行动建议数量",
|
||||
value=len(result.action_plan.actions),
|
||||
delta="已生成"
|
||||
)
|
||||
|
||||
st.divider()
|
||||
|
||||
tab1, tab2, tab3 = st.tabs(["📊 评估结果", "🔍 特征解释", "📋 行动计划"])
|
||||
|
||||
with tab1:
|
||||
st.subheader("模型评估结果")
|
||||
|
||||
eval_col1, eval_col2 = st.columns(2)
|
||||
|
||||
with eval_col1:
|
||||
st.info(f"""
|
||||
**预测信息**
|
||||
- 预测类别: {result.evaluation.class_name.value}
|
||||
- 预测标签: {result.evaluation.predicted_class}
|
||||
- 置信度: {result.evaluation.confidence.value}
|
||||
""")
|
||||
|
||||
with eval_col2:
|
||||
st.info(f"""
|
||||
**概率分布**
|
||||
- 欺诈概率: {result.evaluation.fraud_probability:.4f}
|
||||
- 正常概率: {result.evaluation.normal_probability:.4f}
|
||||
""")
|
||||
|
||||
if result.evaluation.class_name == TransactionClass.FRAUD:
|
||||
st.error(f"⚠️ 该交易被识别为**欺诈交易**,请立即采取行动!")
|
||||
else:
|
||||
st.success(f"✅ 该交易被识别为**正常交易**")
|
||||
|
||||
with tab2:
|
||||
st.subheader("🔍 特征解释")
|
||||
|
||||
st.markdown("""
|
||||
<div style="background-color: #e3f2fd; padding: 15px; border-radius: 10px; margin-bottom: 20px;">
|
||||
<h4 style="margin: 0; color: #1565c0;">💡 什么是特征解释?</h4>
|
||||
<p style="margin: 10px 0 0 0; color: #424242;">
|
||||
就像医生看病时会检查各项指标一样,我们的AI模型也通过分析交易的各项"特征"来判断是否为欺诈。
|
||||
下面这些特征是影响判断结果最重要的因素,让我们来看看它们是如何"告诉"模型这个交易是否有问题的。
|
||||
</p>
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
st.info(f"""
|
||||
**使用的模型**: {result.explanation.model_type}
|
||||
|
||||
**整体判断依据**: {result.explanation.overall_explanation}
|
||||
""")
|
||||
|
||||
st.write("### 📊 关键影响因素分析")
|
||||
st.caption("这些特征对判断结果影响最大,就像破案时的关键线索")
|
||||
|
||||
feature_descriptions = {
|
||||
"Time": "交易发生的时间距离第一次交易经过的秒数。欺诈交易往往在特定时间段更频繁,比如深夜或节假日。",
|
||||
"Amount": "交易金额。异常高额或异常低额的交易都可能引起怀疑,特别是与用户历史消费习惯不符时。",
|
||||
"V1": "经过PCA(主成分分析)转换后的特征1,代表交易数据的某种模式。PCA将原始数据转换成更易分析的形式。",
|
||||
"V2": "经过PCA转换后的特征2,捕捉交易数据的另一种模式。",
|
||||
"V3": "经过PCA转换后的特征3,反映交易数据的特定维度。",
|
||||
"V4": "经过PCA转换后的特征4,可能代表交易频率或模式。",
|
||||
"V5": "经过PCA转换后的特征5,可能涉及交易的时间或空间特征。",
|
||||
"V6": "经过PCA转换后的特征6,可能反映交易的某种统计特性。",
|
||||
"V7": "经过PCA转换后的特征7,可能代表交易的异常程度。",
|
||||
"V8": "经过PCA转换后的特征8,可能涉及交易的上下文信息。",
|
||||
"V9": "经过PCA转换后的特征9,可能反映交易的时间序列特征。",
|
||||
"V10": "经过PCA转换后的特征10,可能代表交易的某种模式。",
|
||||
"V11": "经过PCA转换后的特征11,可能涉及交易的频率特征。",
|
||||
"V12": "经过PCA转换后的特征12,可能反映交易的异常模式。",
|
||||
"V13": "经过PCA转换后的特征13,可能代表交易的某种统计特性。",
|
||||
"V14": "经过PCA转换后的特征14,可能涉及交易的时间特征。",
|
||||
"V15": "经过PCA转换后的特征15,可能反映交易的某种模式。",
|
||||
"V16": "经过PCA转换后的特征16,可能代表交易的异常程度。",
|
||||
"V17": "经过PCA转换后的特征17,可能涉及交易的上下文信息。",
|
||||
"V18": "经过PCA转换后的特征18,可能反映交易的时间序列特征。",
|
||||
"V19": "经过PCA转换后的特征19,可能代表交易的某种模式。",
|
||||
"V20": "经过PCA转换后的特征20,可能涉及交易的频率特征。",
|
||||
"V21": "经过PCA转换后的特征21,可能反映交易的异常模式。",
|
||||
"V22": "经过PCA转换后的特征22,可能代表交易的某种统计特性。",
|
||||
"V23": "经过PCA转换后的特征23,可能涉及交易的时间特征。",
|
||||
"V24": "经过PCA转换后的特征24,可能反映交易的某种模式。",
|
||||
"V25": "经过PCA转换后的特征25,可能代表交易的异常程度。",
|
||||
"V26": "经过PCA转换后的特征26,可能涉及交易的上下文信息。",
|
||||
"V27": "经过PCA转换后的特征27,可能反映交易的时间序列特征。",
|
||||
"V28": "经过PCA转换后的特征28,可能代表交易的某种模式。"
|
||||
}
|
||||
|
||||
for i, feature in enumerate(result.explanation.key_features, 1):
|
||||
importance_percent = min(feature.importance * 100, 100)
|
||||
|
||||
impact_emoji = "📈" if feature.impact == "正向影响" else "📉"
|
||||
impact_color = "#ff5252" if feature.impact == "正向影响" else "#4caf50"
|
||||
|
||||
feature_desc = feature_descriptions.get(feature.feature_name, "该特征是经过PCA转换后的数据特征,用于帮助模型识别交易模式。")
|
||||
|
||||
with st.expander(f"{i}. {feature.feature_name} {impact_emoji}"):
|
||||
st.markdown(f"""
|
||||
<div style="padding: 15px; border-radius: 8px; background-color: #f5f5f5;">
|
||||
|
||||
<p style="margin: 10px 0; color: #616161;">
|
||||
<strong>影响程度:</strong>
|
||||
</p>
|
||||
<div style="background-color: #e0e0e0; border-radius: 5px; height: 20px; margin: 5px 0;">
|
||||
<div style="background-color: {impact_color}; height: 100%; width: {importance_percent}%; border-radius: 5px;"></div>
|
||||
</div>
|
||||
<p style="margin: 5px 0; font-size: 14px; color: {impact_color}; font-weight: bold;">
|
||||
{importance_percent:.1f}% 的影响力
|
||||
</p>
|
||||
|
||||
<p style="margin: 15px 0 10px 0; color: #616161;">
|
||||
<strong>影响方向:</strong> {feature.impact}
|
||||
</p>
|
||||
|
||||
<div style="padding: 12px; background-color: #e3f2fd; border-left: 4px solid #2196f3; border-radius: 4px; margin-bottom: 10px;">
|
||||
<p style="margin: 0; color: #1565c0;">
|
||||
<strong>📖 这个特征是什么?</strong><br>
|
||||
{feature_desc}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div style="padding: 10px; background-color: #fff3e0; border-left: 4px solid #ff9800; border-radius: 4px;">
|
||||
<p style="margin: 0; color: #e65100;">
|
||||
<strong>💡 简单来说:</strong>
|
||||
{"这个特征让模型更倾向于认为这是欺诈交易" if feature.impact == "正向影响" else "这个特征让模型更倾向于认为这是正常交易"}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
with tab3:
|
||||
st.subheader("📋 行动计划")
|
||||
|
||||
st.markdown("""
|
||||
<div style="background-color: #e8f5e9; padding: 15px; border-radius: 10px; margin-bottom: 20px;">
|
||||
<h4 style="margin: 0; color: #2e7d32;">🎯 为什么需要行动计划?</h4>
|
||||
<p style="margin: 10px 0 0 0; color: #424242;">
|
||||
根据检测结果,我们为您准备了具体的行动建议。这些建议按照紧急程度排序,
|
||||
帮助您快速、有效地应对可能的风险。请根据实际情况选择合适的处理方式。
|
||||
</p>
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
st.write("### 🚀 建议采取的行动")
|
||||
st.caption("按优先级排序,从高到低依次处理")
|
||||
|
||||
for action in result.action_plan.actions:
|
||||
priority_info = {
|
||||
Priority.URGENT: {
|
||||
"emoji": "🔴",
|
||||
"color": "#d32f2f",
|
||||
"bg_color": "#ffcdd2",
|
||||
"description": "紧急 - 需要立即处理,不能拖延"
|
||||
},
|
||||
Priority.HIGH: {
|
||||
"emoji": "🟠",
|
||||
"color": "#f57c00",
|
||||
"bg_color": "#ffe0b2",
|
||||
"description": "高优先级 - 尽快处理"
|
||||
},
|
||||
Priority.MEDIUM: {
|
||||
"emoji": "🟡",
|
||||
"color": "#fbc02d",
|
||||
"bg_color": "#fff9c4",
|
||||
"description": "中等优先级 - 适时处理"
|
||||
},
|
||||
Priority.LOW: {
|
||||
"emoji": "🟢",
|
||||
"color": "#388e3c",
|
||||
"bg_color": "#c8e6c9",
|
||||
"description": "低优先级 - 可以稍后处理"
|
||||
},
|
||||
Priority.ROUTINE: {
|
||||
"emoji": "⚪",
|
||||
"color": "#757575",
|
||||
"bg_color": "#e0e0e0",
|
||||
"description": "常规 - 按正常流程处理"
|
||||
}
|
||||
}.get(action.priority, {
|
||||
"emoji": "⚪",
|
||||
"color": "#757575",
|
||||
"bg_color": "#e0e0e0",
|
||||
"description": "常规"
|
||||
})
|
||||
|
||||
with st.container():
|
||||
st.markdown(f"""
|
||||
<div style="padding: 20px; border-radius: 10px; background-color: {priority_info['bg_color']}; margin-bottom: 15px; border-left: 5px solid {priority_info['color']};">
|
||||
<div style="display: flex; align-items: center; margin-bottom: 10px;">
|
||||
<span style="font-size: 24px; margin-right: 10px;">{priority_info['emoji']}</span>
|
||||
<div>
|
||||
<h4 style="margin: 0; color: {priority_info['color']};">{action.action}</h4>
|
||||
<p style="margin: 5px 0 0 0; font-size: 14px; color: #616161;">
|
||||
<strong>优先级:</strong> {action.priority.value} - {priority_info['description']}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style="background-color: white; padding: 12px; border-radius: 6px; margin-top: 10px;">
|
||||
<p style="margin: 0; color: #424242;">
|
||||
<strong>💡 为什么要这样做?</strong><br>
|
||||
{action.reason}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
st.divider()
|
||||
|
||||
st.header("📝 使用说明")
|
||||
|
||||
with st.expander("查看使用说明"):
|
||||
st.markdown("""
|
||||
### 如何使用本系统
|
||||
|
||||
#### 方式1:上传CSV文件
|
||||
1. **上传文件**: 点击"选择CSV文件"按钮上传包含交易数据的CSV文件
|
||||
2. **查看数据**: 系统会显示数据预览
|
||||
3. **选择交易**: 输入行号选择要分析的交易
|
||||
4. **加载数据**: 点击"加载选中的交易"按钮
|
||||
5. **开始检测**: 点击"检测欺诈"按钮开始分析
|
||||
|
||||
#### 方式2:手动输入
|
||||
1. **输入特征**: 在表单中输入30个特征值
|
||||
- Time: 交易发生时间(秒)
|
||||
- V1-V28: PCA转换后的特征
|
||||
- Amount: 交易金额
|
||||
|
||||
2. **点击检测**: 点击"检测欺诈"按钮开始分析
|
||||
|
||||
3. **查看结果**: 系统会返回三个部分的结果
|
||||
- **评估结果**: 模型的预测类别和概率
|
||||
- **特征解释**: 影响预测的关键特征
|
||||
- **行动计划**: 建议的处理措施
|
||||
|
||||
### CSV文件格式要求
|
||||
|
||||
CSV文件必须包含以下列:
|
||||
- Time, V1, V2, V3, V4, V5, V6, V7, V8, V9
|
||||
- V10, V11, V12, V13, V14, V15, V16, V17, V18, V19
|
||||
- V20, V21, V22, V23, V24, V25, V26, V27, V28, Amount
|
||||
- Class (可选,如果存在会被自动删除)
|
||||
|
||||
### 系统特点
|
||||
|
||||
- ✅ 使用随机森林模型进行预测
|
||||
- ✅ 支持CSV文件批量处理
|
||||
- ✅ 提供特征重要性分析
|
||||
- ✅ 根据置信度生成行动建议
|
||||
- ✅ 实时分析,快速响应
|
||||
|
||||
### 注意事项
|
||||
|
||||
- 本系统仅供演示使用
|
||||
- 实际应用中需要结合人工审核
|
||||
- 建议定期更新模型以保持准确性
|
||||
""")
|
||||
226
src/train.py
Normal file
226
src/train.py
Normal file
@ -0,0 +1,226 @@
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.metrics import (
|
||||
precision_score, recall_score, f1_score, accuracy_score,
|
||||
precision_recall_curve, auc, confusion_matrix
|
||||
)
|
||||
from imblearn.over_sampling import SMOTE
|
||||
import numpy as np
|
||||
import logging
|
||||
import joblib
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple, Optional
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from features import ModelMetrics, TrainingResult
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreditCardFraudModelTrainer:
|
||||
def __init__(self, model_dir: str = "models"):
|
||||
self.model_dir = Path(model_dir)
|
||||
self.model_dir.mkdir(exist_ok=True)
|
||||
|
||||
self.models = {
|
||||
"logistic_regression": LogisticRegression(
|
||||
random_state=42,
|
||||
class_weight="balanced",
|
||||
max_iter=1000
|
||||
),
|
||||
"random_forest": RandomForestClassifier(
|
||||
random_state=42,
|
||||
class_weight="balanced",
|
||||
n_estimators=100
|
||||
)
|
||||
}
|
||||
|
||||
self.scaler = StandardScaler()
|
||||
self.best_model = None
|
||||
self.best_model_name = None
|
||||
self.best_model_score = 0
|
||||
self.evaluation_results = {}
|
||||
|
||||
def train(self, X_train: np.ndarray, y_train: np.ndarray, use_smote: bool = False) -> Dict[str, any]:
|
||||
logger.info("开始训练模型...")
|
||||
|
||||
X_train_scaled = self.scaler.fit_transform(X_train)
|
||||
|
||||
if use_smote:
|
||||
logger.info("使用SMOTE处理不平衡数据...")
|
||||
smote = SMOTE(random_state=42)
|
||||
X_train_scaled, y_train = smote.fit_resample(X_train_scaled, y_train)
|
||||
logger.info(f"SMOTE处理后,训练集形状: X={X_train_scaled.shape}, y={y_train.shape}")
|
||||
|
||||
for model_name, model in self.models.items():
|
||||
logger.info(f"训练模型: {model_name}")
|
||||
model.fit(X_train_scaled, y_train)
|
||||
|
||||
model_path = self.model_dir / f"{model_name}_model.joblib"
|
||||
joblib.dump(model, model_path)
|
||||
logger.info(f"模型已保存: {model_path}")
|
||||
|
||||
scaler_path = self.model_dir / "scaler.joblib"
|
||||
joblib.dump(self.scaler, scaler_path)
|
||||
logger.info(f"特征缩放器已保存: {scaler_path}")
|
||||
|
||||
logger.info("所有模型训练完成")
|
||||
return {"status": "success", "message": "模型训练完成"}
|
||||
|
||||
def evaluate(self, X_test: np.ndarray, y_test: np.ndarray) -> Dict[str, Dict[str, any]]:
|
||||
logger.info("开始评估模型...")
|
||||
|
||||
X_test_scaled = self.scaler.transform(X_test)
|
||||
|
||||
for model_name, model in self.models.items():
|
||||
logger.info(f"评估模型: {model_name}")
|
||||
|
||||
y_pred = model.predict(X_test_scaled)
|
||||
y_pred_proba = model.predict_proba(X_test_scaled)[:, 1]
|
||||
|
||||
accuracy = accuracy_score(y_test, y_pred)
|
||||
precision = precision_score(y_test, y_pred)
|
||||
recall = recall_score(y_test, y_pred)
|
||||
f1 = f1_score(y_test, y_pred)
|
||||
|
||||
precision_curve, recall_curve, _ = precision_recall_curve(y_test, y_pred_proba)
|
||||
pr_auc = auc(recall_curve, precision_curve)
|
||||
|
||||
cm = confusion_matrix(y_test, y_pred)
|
||||
|
||||
self.evaluation_results[model_name] = {
|
||||
"accuracy": accuracy,
|
||||
"precision": precision,
|
||||
"recall": recall,
|
||||
"f1_score": f1,
|
||||
"pr_auc": pr_auc,
|
||||
"confusion_matrix": cm,
|
||||
"y_pred": y_pred,
|
||||
"y_pred_proba": y_pred_proba
|
||||
}
|
||||
|
||||
if pr_auc > self.best_model_score:
|
||||
self.best_model_score = pr_auc
|
||||
self.best_model = model
|
||||
self.best_model_name = model_name
|
||||
|
||||
logger.info(f"模型 {model_name} 评估完成, PR-AUC: {pr_auc:.4f}, F1: {f1:.4f}")
|
||||
|
||||
logger.info(f"最佳模型: {self.best_model_name}, PR-AUC: {self.best_model_score:.4f}")
|
||||
return self.evaluation_results
|
||||
|
||||
def get_best_model(self) -> Tuple[Optional[str], Optional[any], float]:
|
||||
if self.best_model is None:
|
||||
logger.warning("尚未训练或评估模型")
|
||||
return None, None, 0
|
||||
return self.best_model_name, self.best_model, self.best_model_score
|
||||
|
||||
def predict(self, X: np.ndarray, model_name: Optional[str] = None) -> np.ndarray:
|
||||
X_scaled = self.scaler.transform(X)
|
||||
|
||||
if model_name is not None:
|
||||
if model_name not in self.models:
|
||||
logger.error(f"模型 {model_name} 未找到")
|
||||
raise ValueError(f"模型 {model_name} 未找到")
|
||||
model = self.models[model_name]
|
||||
else:
|
||||
if self.best_model is None:
|
||||
logger.error("尚未训练或评估模型")
|
||||
raise RuntimeError("尚未训练或评估模型,无法进行预测")
|
||||
model = self.best_model
|
||||
model_name = self.best_model_name
|
||||
|
||||
logger.info(f"使用模型 {model_name} 进行预测")
|
||||
return model.predict(X_scaled)
|
||||
|
||||
def predict_proba(self, X: np.ndarray, model_name: Optional[str] = None) -> np.ndarray:
|
||||
X_scaled = self.scaler.transform(X)
|
||||
|
||||
if model_name is not None:
|
||||
if model_name not in self.models:
|
||||
logger.error(f"模型 {model_name} 未找到")
|
||||
return None
|
||||
model = self.models[model_name]
|
||||
else:
|
||||
if self.best_model is None:
|
||||
logger.error("尚未训练或评估模型")
|
||||
return None
|
||||
model = self.best_model
|
||||
model_name = self.best_model_name
|
||||
|
||||
logger.info(f"使用模型 {model_name} 进行概率预测")
|
||||
return model.predict_proba(X_scaled)[:, 1]
|
||||
|
||||
def load_model(self, model_name: str) -> bool:
|
||||
try:
|
||||
model_path = self.model_dir / f"{model_name}_model.joblib"
|
||||
model = joblib.load(model_path)
|
||||
self.models[model_name] = model
|
||||
|
||||
if self.best_model is None:
|
||||
self.best_model = model
|
||||
self.best_model_name = model_name
|
||||
logger.info(f"设置 {model_name} 为默认最佳模型")
|
||||
|
||||
logger.info(f"模型加载成功: {model_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"加载模型失败: {e}")
|
||||
return False
|
||||
|
||||
def load_scaler(self) -> bool:
|
||||
try:
|
||||
scaler_path = self.model_dir / "scaler.joblib"
|
||||
self.scaler = joblib.load(scaler_path)
|
||||
logger.info(f"特征缩放器加载成功: {scaler_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"加载特征缩放器失败: {e}")
|
||||
return False
|
||||
|
||||
def print_evaluation_results(self) -> None:
|
||||
print("\n=== 模型评估结果 ===")
|
||||
for model_name, results in self.evaluation_results.items():
|
||||
print(f"\n模型: {model_name}")
|
||||
print("-" * 30)
|
||||
print(f"准确率 (Accuracy): {results['accuracy']:.4f}")
|
||||
print(f"精确率 (Precision): {results['precision']:.4f}")
|
||||
print(f"召回率 (Recall): {results['recall']:.4f}")
|
||||
print(f"F1分数 (F1-Score): {results['f1_score']:.4f}")
|
||||
print(f"PR-AUC: {results['pr_auc']:.4f}")
|
||||
print("\n混淆矩阵:")
|
||||
print(results['confusion_matrix'])
|
||||
|
||||
print(f"\n最佳模型: {self.best_model_name}")
|
||||
print(f"最佳模型PR-AUC: {self.best_model_score:.4f}")
|
||||
|
||||
|
||||
def train_and_evaluate(data_path: str = "data/creditcard.csv", use_smote: bool = False) -> CreditCardFraudModelTrainer:
|
||||
from data import load_data
|
||||
|
||||
logger.info("=== 信用卡欺诈检测系统开始运行 ===")
|
||||
|
||||
processor = load_data(data_path)
|
||||
X_train = processor.train_features
|
||||
y_train = processor.train_labels
|
||||
X_test = processor.test_features
|
||||
y_test = processor.test_labels
|
||||
|
||||
logger.info(f"\n训练集: {X_train.shape}, {y_train.shape}")
|
||||
logger.info(f"测试集: {X_test.shape}, {y_test.shape}")
|
||||
|
||||
trainer = CreditCardFraudModelTrainer()
|
||||
train_result = trainer.train(X_train, y_train, use_smote=use_smote)
|
||||
logger.info(train_result["message"])
|
||||
|
||||
evaluation_results = trainer.evaluate(X_test, y_test)
|
||||
trainer.print_evaluation_results()
|
||||
|
||||
logger.info("\n=== 信用卡欺诈检测系统运行完成 ===")
|
||||
return trainer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_and_evaluate()
|
||||
64
tests/test_agent.py
Normal file
64
tests/test_agent.py
Normal file
@ -0,0 +1,64 @@
|
||||
import pytest
|
||||
from src.agent_app import CreditCardFraudAgent, create_agent, Tool
|
||||
|
||||
|
||||
def test_agent_initialization():
|
||||
agent = CreditCardFraudAgent(model_dir="models", model_name="random_forest")
|
||||
assert agent.inference is not None
|
||||
assert len(agent.tools) == 2
|
||||
|
||||
|
||||
def test_create_agent():
|
||||
agent = create_agent(model_dir="models", model_name="random_forest")
|
||||
assert isinstance(agent, CreditCardFraudAgent)
|
||||
|
||||
|
||||
def test_list_tools():
|
||||
agent = create_agent()
|
||||
tools = agent.list_tools()
|
||||
assert len(tools) == 2
|
||||
tool_names = [tool["name"] for tool in tools]
|
||||
assert "predict_fraud" in tool_names
|
||||
assert "analyze_transaction" in tool_names
|
||||
|
||||
|
||||
def test_tool_structure():
|
||||
agent = create_agent()
|
||||
for tool in agent.tools:
|
||||
assert hasattr(tool, "name")
|
||||
assert hasattr(tool, "description")
|
||||
assert hasattr(tool, "func")
|
||||
assert callable(tool.func)
|
||||
|
||||
|
||||
def test_analyze_transaction():
|
||||
agent = create_agent()
|
||||
transaction = [
|
||||
0, -1.36, -0.07, 2.54, 1.38, -0.34, 0.46, 0.24, 0.10, 0.36,
|
||||
0.09, -0.55, -0.62, -0.99, -0.31, 1.47, -0.47, 0.21, 0.03, 0.40,
|
||||
0.25, -0.02, 0.28, -0.11, 0.07, 0.13, -0.19, 0.13, -0.02, 149.62
|
||||
]
|
||||
analysis = agent._analyze_transaction(transaction)
|
||||
assert "feature_count" in analysis
|
||||
assert "amount" in analysis
|
||||
assert "time" in analysis
|
||||
assert "statistics" in analysis
|
||||
assert "anomalies" in analysis
|
||||
assert analysis["feature_count"] == 30
|
||||
assert analysis["amount"] == 149.62
|
||||
|
||||
|
||||
def test_process_transaction():
|
||||
agent = create_agent()
|
||||
transaction = [
|
||||
0, -1.36, -0.07, 2.54, 1.38, -0.34, 0.46, 0.24, 0.10, 0.36,
|
||||
0.09, -0.55, -0.62, -0.99, -0.31, 1.47, -0.47, 0.21, 0.03, 0.40,
|
||||
0.25, -0.02, 0.28, -0.11, 0.07, 0.13, -0.19, 0.13, -0.02, 149.62
|
||||
]
|
||||
result = agent.process_transaction(transaction)
|
||||
assert result.evaluation is not None
|
||||
assert result.explanation is not None
|
||||
assert result.action_plan is not None
|
||||
assert hasattr(result.evaluation, "predicted_class")
|
||||
assert hasattr(result.evaluation, "fraud_probability")
|
||||
assert hasattr(result.action_plan, "actions")
|
||||
62
tests/test_data.py
Normal file
62
tests/test_data.py
Normal file
@ -0,0 +1,62 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from src.data import CreditCardDataProcessor, load_data
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_data_processor_initialization():
|
||||
processor = CreditCardDataProcessor("data/creditcard.csv")
|
||||
assert processor.file_path == "data/creditcard.csv"
|
||||
assert processor.data is None
|
||||
|
||||
|
||||
def test_load_data():
|
||||
processor = CreditCardDataProcessor("data/creditcard.csv")
|
||||
processor.load_data()
|
||||
assert processor.data is not None
|
||||
assert processor.data.height > 0
|
||||
assert "Class" in processor.data.columns
|
||||
|
||||
|
||||
def test_validate_data():
|
||||
processor = CreditCardDataProcessor("data/creditcard.csv")
|
||||
processor.load_data()
|
||||
processor.validate_data()
|
||||
assert processor.data is not None
|
||||
|
||||
|
||||
def test_split_data_by_time():
|
||||
processor = CreditCardDataProcessor("data/creditcard.csv")
|
||||
processor.load_data()
|
||||
train_data, test_data = processor.split_data_by_time(test_ratio=0.2)
|
||||
assert train_data is not None
|
||||
assert test_data is not None
|
||||
assert train_data.height > test_data.height
|
||||
|
||||
|
||||
def test_prepare_features_labels():
|
||||
processor = CreditCardDataProcessor("data/creditcard.csv")
|
||||
processor.load_data()
|
||||
processor.split_data_by_time()
|
||||
processor.prepare_features_labels()
|
||||
assert processor.train_features is not None
|
||||
assert processor.train_labels is not None
|
||||
assert processor.test_features is not None
|
||||
assert processor.test_labels is not None
|
||||
|
||||
|
||||
def test_load_data_function():
|
||||
processor = load_data("data/creditcard.csv")
|
||||
assert processor.data is not None
|
||||
assert processor.train_features is not None
|
||||
assert processor.test_features is not None
|
||||
|
||||
|
||||
def test_get_statistics():
|
||||
processor = load_data("data/creditcard.csv")
|
||||
stats = processor.get_statistics()
|
||||
assert "总记录数" in stats
|
||||
assert "特征数" in stats
|
||||
assert "欺诈交易数" in stats
|
||||
assert "非欺诈交易数" in stats
|
||||
assert stats["总记录数"] > 0
|
||||
137
tests/test_features.py
Normal file
137
tests/test_features.py
Normal file
@ -0,0 +1,137 @@
|
||||
import pytest
|
||||
from src.features import (
|
||||
TransactionFeatures, EvaluationResult, ExplanationResult,
|
||||
ActionPlan, DecisionResult, ModelMetrics, TrainingResult,
|
||||
TransactionClass, ConfidenceLevel, Priority,
|
||||
FeatureContribution, Action
|
||||
)
|
||||
|
||||
|
||||
def test_transaction_features():
|
||||
features = TransactionFeatures(
|
||||
time=0.0,
|
||||
v1=-1.36, v2=-0.07, v3=2.54, v4=1.38, v5=-0.34,
|
||||
v6=0.46, v7=0.24, v8=0.10, v9=0.36, v10=0.09,
|
||||
v11=-0.55, v12=-0.62, v13=-0.99, v14=-0.31, v15=1.47,
|
||||
v16=-0.47, v17=0.21, v18=0.03, v19=0.40, v20=0.25,
|
||||
v21=-0.02, v22=0.28, v23=-0.11, v24=0.07, v25=0.13,
|
||||
v26=-0.19, v27=0.13, v28=-0.02, amount=149.62
|
||||
)
|
||||
assert features.time == 0.0
|
||||
assert features.amount == 149.62
|
||||
assert len(features.to_array()) == 30
|
||||
|
||||
|
||||
def test_evaluation_result():
|
||||
result = EvaluationResult(
|
||||
predicted_class=1,
|
||||
class_name=TransactionClass.FRAUD,
|
||||
fraud_probability=0.95,
|
||||
normal_probability=0.05,
|
||||
confidence=ConfidenceLevel.HIGH
|
||||
)
|
||||
assert result.predicted_class == 1
|
||||
assert result.class_name == TransactionClass.FRAUD
|
||||
assert result.fraud_probability == 0.95
|
||||
assert result.confidence == ConfidenceLevel.HIGH
|
||||
|
||||
|
||||
def test_feature_contribution():
|
||||
contribution = FeatureContribution(
|
||||
feature_name="V14",
|
||||
value=-0.99,
|
||||
importance=0.15,
|
||||
contribution=-0.15,
|
||||
impact="负"
|
||||
)
|
||||
assert contribution.feature_name == "V14"
|
||||
assert contribution.value == -0.99
|
||||
assert contribution.importance == 0.15
|
||||
assert contribution.impact == "负"
|
||||
|
||||
|
||||
def test_explanation_result():
|
||||
explanation = ExplanationResult(
|
||||
model_type="RandomForestClassifier",
|
||||
predicted_class=TransactionClass.FRAUD,
|
||||
key_features=[],
|
||||
overall_explanation="测试解释"
|
||||
)
|
||||
assert explanation.model_type == "RandomForestClassifier"
|
||||
assert explanation.predicted_class == TransactionClass.FRAUD
|
||||
|
||||
|
||||
def test_action():
|
||||
action = Action(
|
||||
priority=Priority.URGENT,
|
||||
action="冻结账户",
|
||||
reason="检测到欺诈"
|
||||
)
|
||||
assert action.priority == Priority.URGENT
|
||||
assert action.action == "冻结账户"
|
||||
|
||||
|
||||
def test_action_plan():
|
||||
plan = ActionPlan(
|
||||
predicted_class=TransactionClass.FRAUD,
|
||||
confidence=ConfidenceLevel.HIGH,
|
||||
actions=[]
|
||||
)
|
||||
assert plan.predicted_class == TransactionClass.FRAUD
|
||||
assert plan.confidence == ConfidenceLevel.HIGH
|
||||
|
||||
|
||||
def test_decision_result():
|
||||
result = DecisionResult(
|
||||
evaluation=EvaluationResult(
|
||||
predicted_class=1,
|
||||
class_name=TransactionClass.FRAUD,
|
||||
fraud_probability=0.95,
|
||||
normal_probability=0.05,
|
||||
confidence=ConfidenceLevel.HIGH
|
||||
),
|
||||
explanation=ExplanationResult(
|
||||
model_type="RandomForestClassifier",
|
||||
predicted_class=TransactionClass.FRAUD,
|
||||
key_features=[],
|
||||
overall_explanation="测试"
|
||||
),
|
||||
action_plan=ActionPlan(
|
||||
predicted_class=TransactionClass.FRAUD,
|
||||
confidence=ConfidenceLevel.HIGH,
|
||||
actions=[]
|
||||
),
|
||||
timestamp="2026-01-15"
|
||||
)
|
||||
assert result.evaluation.predicted_class == 1
|
||||
assert result.explanation.model_type == "RandomForestClassifier"
|
||||
assert result.action_plan.predicted_class == TransactionClass.FRAUD
|
||||
|
||||
|
||||
def test_model_metrics():
|
||||
metrics = ModelMetrics(
|
||||
accuracy=0.95,
|
||||
precision=0.90,
|
||||
recall=0.85,
|
||||
f1_score=0.87,
|
||||
pr_auc=0.92
|
||||
)
|
||||
assert metrics.accuracy == 0.95
|
||||
assert metrics.precision == 0.90
|
||||
assert metrics.pr_auc == 0.92
|
||||
|
||||
|
||||
def test_training_result():
|
||||
result = TrainingResult(
|
||||
model_name="random_forest",
|
||||
metrics=ModelMetrics(
|
||||
accuracy=0.95,
|
||||
precision=0.90,
|
||||
recall=0.85,
|
||||
f1_score=0.87,
|
||||
pr_auc=0.92
|
||||
),
|
||||
confusion_matrix=[[100, 5], [10, 85]]
|
||||
)
|
||||
assert result.model_name == "random_forest"
|
||||
assert result.metrics.accuracy == 0.95
|
||||
Loading…
Reference in New Issue
Block a user