sms-castle-walls/src/streamlit_app.py
d8ac31d62c fix(streamlit_app): 使用绝对路径加载数据集防止路径错误
修改数据加载逻辑,使用os.path构建绝对路径来确保在不同工作目录下都能正确找到数据集文件
2026-01-15 17:54:52 +08:00

553 lines
24 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import html
import os
import sys

import pandas as pd
import streamlit as st

# 添加项目根目录到Python路径
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.agent import agent, get_agent
from src.data import load_data, preprocess_data, split_data
from src.models import train_model, save_model, load_model, compare_models
# Page-level configuration. It is kept ahead of every other `st.*` call,
# which older Streamlit versions require.
st.set_page_config(
    page_title="垃圾短信分类器",
    # The previous icon was a mis-encoded replacement character (U+FFFD);
    # the castle emoji matches the app's theme (see the footer banner).
    page_icon="🏰",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Global stylesheet for the medieval theme. It is kept in one constant so the
# markup lives apart from the call that injects it. It is passed to
# st.markdown as raw HTML, which is why unsafe_allow_html=True is needed for
# the <style> tag to take effect.
# NOTE(review): some selectors (e.g. .stTitle, .stContainer) may not match any
# class Streamlit actually emits; verify in the browser devtools.
_MEDIEVAL_THEME_CSS = """<style>
/* 基础样式 */
body {
background-color: #1a1a2e;
color: #e0e0e0;
font-family: 'Georgia', serif;
}
/* 标题样式 */
.stTitle {
color: #d4af37;
font-family: 'Garamond', serif;
text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.8);
border-bottom: 2px solid #d4af37;
padding-bottom: 10px;
}
/* 侧边栏样式 */
.stSidebar {
background-color: #16213e;
border-right: 2px solid #d4af37;
}
/* 卡片和容器 */
.stExpander, .stContainer {
background-color: #0f3460;
border: 1px solid #d4af37;
border-radius: 8px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);
}
/* 按钮样式 */
.stButton > button {
background-color: #c8102e;
color: #ffffff;
border: 2px solid #d4af37;
border-radius: 8px;
font-family: 'Georgia', serif;
font-weight: bold;
padding: 10px 20px;
transition: all 0.3s ease;
}
.stButton > button:hover {
background-color: #8b0000;
color: #d4af37;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);
transform: translateY(-2px);
}
/* 输入框样式 */
.stTextArea, .stSelectbox > div {
background-color: #16213e;
border: 2px solid #d4af37;
border-radius: 8px;
color: #e0e0e0;
font-family: 'Georgia', serif;
}
/* 标题和文本样式 */
h1, h2, h3, h4, h5, h6 {
color: #d4af37;
font-family: 'Garamond', serif;
}
/* 分隔线 */
hr {
border: 1px solid #d4af37;
}
/* 信息卡片 */
.stAlert {
background-color: #0f3460;
border: 2px solid #d4af37;
border-radius: 8px;
color: #e0e0e0;
}
/* 页脚 */
footer {
color: #d4af37;
font-family: 'Georgia', serif;
text-align: center;
padding: 20px;
border-top: 2px solid #d4af37;
margin-top: 40px;
}
</style>"""
st.markdown(_MEDIEVAL_THEME_CSS, unsafe_allow_html=True)
# Page banner: a centred, gold-bordered title card, then a horizontal rule
# separating it from the rest of the page.
_TITLE_BANNER_HTML = """
<div style="text-align: center; padding: 20px; border: 3px solid #d4af37; border-radius: 10px; background-color: #16213e; box-shadow: 0 8px 16px rgba(0, 0, 0, 0.8);">
<h1 style="color: #d4af37; font-family: 'Garamond', serif; text-shadow: 3px 3px 6px rgba(0, 0, 0, 0.8); margin: 0;">
⚔️ 中世纪垃圾短信分类器
</h1>
<p style="color: #e0e0e0; font-style: italic; margin-top: 10px;">
保护您的通信,抵御垃圾信息的入侵
</p>
</div>
"""
st.markdown(_TITLE_BANNER_HTML, unsafe_allow_html=True)
st.markdown("---")
# Sidebar: the model and language pickers, followed by a short description of
# the system. Defines `model_option` and `lang_option`, which later parts of
# the script read.
with st.sidebar:
    st.markdown("""
<div style="text-align: center; padding: 10px; border-bottom: 2px solid #d4af37; margin-bottom: 20px;">
<h2 style="color: #d4af37; font-family: 'Garamond', serif; margin: 0;">
🛡️ 骑士工坊
</h2>
<p style="color: #e0e0e0; font-size: 14px; margin: 5px 0 0;">
系统配置
</p>
</div>
""", unsafe_allow_html=True)

    # Model picker. The HTML heading is the visible label; a real but collapsed
    # widget label is still passed because Streamlit warns about empty labels
    # (accessibility).
    st.markdown("""
<div style="margin-bottom: 20px;">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; font-size: 18px;">
⚔️ 选择武器
</h3>
<p style="color: #e0e0e0; font-size: 14px; margin: 5px 0 10px;">
选择用于抵御垃圾信息的武器
</p>
</div>
""", unsafe_allow_html=True)
    model_option = st.selectbox(
        label="选择武器",
        options=["lightgbm", "logistic_regression"],
        index=0,
        format_func=lambda x: "圣光使者 (LightGBM)" if x == "lightgbm" else "智慧之剑 (Logistic Regression)",
        label_visibility="collapsed",
    )

    # Language picker (same visible-heading / collapsed-label pattern).
    st.markdown("""
<div style="margin-bottom: 20px;">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; font-size: 18px;">
📜 选择语言
</h3>
<p style="color: #e0e0e0; font-size: 14px; margin: 5px 0 10px;">
选择预言师的语言
</p>
</div>
""", unsafe_allow_html=True)
    lang_option = st.selectbox(
        label="选择语言",
        options=["中文", "英文"],
        index=0,
        label_visibility="collapsed",
    )

    # "About" section describing the system.
    st.markdown("---")
    st.markdown("""
<div style="text-align: center; padding: 10px; border-bottom: 2px solid #d4af37; margin-bottom: 20px;">
<h2 style="color: #d4af37; font-family: 'Garamond', serif; margin: 0;">
🏰 关于城堡
</h2>
</div>
""", unsafe_allow_html=True)
    st.markdown("""
<div style="background-color: #0f3460; border: 2px solid #d4af37; border-radius: 8px; padding: 15px; color: #e0e0e0; font-family: 'Georgia', serif;">
<p><strong>🛡️ 城堡防御系统</strong></p>
<p>这是一座由现代魔法和古老智慧构建的防御城堡:</p>
<ul>
<li>💫 使用圣光使者 (LightGBM) 和智慧之剑 (Logistic Regression) 守护</li>
<li>🧙 由DeepSeek预言师提供智慧解释</li>
<li>🤖 通过魔法使者 (Agent) 整合所有力量</li>
</ul>
<p style="margin-top: 15px; font-size: 14px; font-style: italic;">
保护您的通信不受垃圾信息的侵袭!
</p>
</div>
""", unsafe_allow_html=True)
# Main content area: two equal-width columns. The left column takes the input
# (single message, batch CSV, optional retraining); the right column shows the
# results.
col1, col2 = st.columns([1, 1], gap="large")

with col1:
    # --- Single-message input -------------------------------------------------
    st.markdown("""
<div style="background-color: #0f3460; border: 2px solid #d4af37; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);">
<h2 style="color: #d4af37; font-family: 'Garamond', serif; text-align: center; border-bottom: 1px solid #d4af37; padding-bottom: 10px;">
📜 信件输入
</h2>
<p style="color: #e0e0e0; text-align: center; font-style: italic;">
输入需要检查的信件内容
</p>
</div>
""", unsafe_allow_html=True)
    # The HTML header above is the visible label. A non-empty, collapsed label
    # is passed because Streamlit warns about empty widget labels.
    sms_input = st.text_area(
        label="信件内容",
        label_visibility="collapsed",
        height=200,
        placeholder="例如WINNER!! As a valued network customer you have been selected to receivea £900 prize reward!",
        help="输入需要分类的短信内容",
    )
    # The button stays disabled until the input contains non-whitespace text.
    classify_button = st.button(
        "⚔️ 开始检查",
        type="primary",
        use_container_width=True,
        disabled=sms_input.strip() == "",
    )

    # --- Batch upload ---------------------------------------------------------
    st.markdown("---")
    st.markdown("""
<div style="background-color: #0f3460; border: 2px solid #d4af37; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);">
<h2 style="color: #d4af37; font-family: 'Garamond', serif; text-align: center; border-bottom: 1px solid #d4af37; padding-bottom: 10px;">
📦 批量检查
</h2>
<p style="color: #e0e0e0; text-align: center; font-style: italic;">
上传多封信件进行批量检查
</p>
</div>
""", unsafe_allow_html=True)
    uploaded_file = st.file_uploader(
        label="批量信件文件",
        label_visibility="collapsed",
        type=["csv"],
        help="上传包含短信文本的CSV文件需要包含text列",
    )

    # --- Optional model (re)training ------------------------------------------
    with st.expander("🔧 锻造武器", expanded=False):
        st.markdown("""
<div style="background-color: #0f3460; border: 1px solid #d4af37; border-radius: 8px; padding: 15px;">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; margin-top: 0;">
铁匠工坊
</h3>
<p style="color: #e0e0e0; font-size: 14px;">
重新锻造您的武器,提升防御能力
</p>
</div>
""", unsafe_allow_html=True)
        if st.button("⚒️ 重新锻造武器"):
            with st.spinner("🔨 铁匠正在锻造武器..."):
                try:
                    # Resolve the dataset relative to this file rather than the
                    # current working directory, so it is found no matter where
                    # `streamlit run` is launched from.
                    data_path = os.path.abspath(
                        os.path.join(os.path.dirname(__file__), "..", "data", "spam.csv")
                    )
                    df = load_data(data_path)
                    processed_df = preprocess_data(df)
                    # Only the training split is used here; the held-out split
                    # and the returned params are not needed by the UI.
                    train_df, _test_df = split_data(processed_df)
                    model, _params = train_model(train_df, model_type=model_option)
                    save_model(model, model_option)
                    weapon_name = (
                        "圣光使者 (LightGBM)" if model_option == "lightgbm" else "智慧之剑 (Logistic Regression)"
                    )
                    st.markdown("""
<div style="background-color: #0f3460; border: 2px solid #d4af37; border-radius: 8px; padding: 15px; color: #d4af37; font-weight: bold;">
✨ 武器锻造完成!
<p style="color: #e0e0e0; font-weight: normal; margin-top: 10px;">
您的 {} 已准备好进行战斗!
</p>
</div>
""".format(weapon_name), unsafe_allow_html=True)
                except Exception as e:
                    # UI boundary: report any training failure to the user. The
                    # message is HTML-escaped because it is rendered as raw
                    # HTML, and exception text often contains '<...>'.
                    st.markdown("""
<div style="background-color: #0f3460; border: 2px solid #c8102e; border-radius: 8px; padding: 15px; color: #c8102e; font-weight: bold;">
❌ 锻造失败!
<p style="color: #e0e0e0; font-weight: normal; margin-top: 10px;">
铁匠遇到了问题:{}
</p>
</div>
""".format(html.escape(str(e))), unsafe_allow_html=True)
def _as_html_text(value):
    """Return ``str(value)`` HTML-escaped for interpolation into raw markup.

    Line breaks become ``<br>``, so multi-line text keeps its layout and a blank
    line cannot end the surrounding markdown HTML block early.
    """
    return "<br>".join(html.escape(str(value)).splitlines())


with col2:
    # Right column: results for the single message and for the uploaded batch.
    st.header("分类结果")

    # --- Single-message result ------------------------------------------------
    if classify_button and sms_input.strip():
        with st.spinner("正在分类..."):
            try:
                # NOTE(review): neither `model_option` nor `lang_option` is passed
                # to the agent, and `get_agent` is imported but never called.
                # Confirm whether classification should honor the sidebar
                # selections.
                result = agent.classify_and_explain(sms_input)
                classification = result['classification']
                explanation = result['explanation']

                st.subheader("📋 分类标签")
                # Markdown emphasis (**...**) is not rendered inside a raw HTML
                # block, so the label is emphasised with <strong> instead.
                if classification['label'] == "spam":
                    st.markdown("""
<div style="background-color: #0f3460; border: 2px solid #c8102e; border-radius: 8px; padding: 15px; text-align: center; font-size: 18px; font-weight: bold;">
⚠️ <span style="color: #c8102e;">这是一封<strong>垃圾信件</strong></span>
<p style="font-size: 14px; font-weight: normal; margin-top: 10px; color: #e0e0e0;">
建议您谨慎对待此信件!
</p>
</div>
""", unsafe_allow_html=True)
                else:
                    st.markdown("""
<div style="background-color: #0f3460; border: 2px solid #228B22; border-radius: 8px; padding: 15px; text-align: center; font-size: 18px; font-weight: bold;">
✅ <span style="color: #228B22;">这是一封<strong>正常信件</strong></span>
<p style="font-size: 14px; font-weight: normal; margin-top: 10px; color: #e0e0e0;">
此信件安全,可以放心阅读!
</p>
</div>
""", unsafe_allow_html=True)

                # Probability bar chart.
                st.markdown("""
<div style="margin-top: 20px; background-color: #0f3460; border: 2px solid #d4af37; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; text-align: center; border-bottom: 1px solid #d4af37; padding-bottom: 10px;">
📊 预言概率
</h3>
</div>
""", unsafe_allow_html=True)
                prob_df = pd.DataFrame.from_dict(
                    classification['probability'],
                    orient='index',
                    columns=['概率'],
                )
                # Rows are renamed by key, not by position, so the labels stay
                # correct whatever order the probability dict arrives in.
                prob_labels = (
                    {"spam": "垃圾信件", "ham": "正常信件"}
                    if lang_option == '中文'
                    else {"spam": "Spam", "ham": "Ham"}
                )
                prob_df.index = [prob_labels.get(key, key) for key in prob_df.index]
                st.bar_chart(prob_df)

                # Raw classification payload.
                st.markdown("""
<div style="margin-top: 20px; background-color: #0f3460; border: 2px solid #d4af37; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; text-align: center; border-bottom: 1px solid #d4af37; padding-bottom: 10px;">
🔍 详细预言
</h3>
</div>
""", unsafe_allow_html=True)
                with st.expander("查看详细分类结果", expanded=True):
                    st.json(classification, expanded=False)

                # Explanation. The text comes from the agent (described in the
                # sidebar as an LLM) and may echo the user's message, yet it is
                # rendered as raw HTML, so every value is escaped first.
                st.markdown("""
<div style="margin-top: 20px; background-color: #0f3460; border: 2px solid #d4af37; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; text-align: center; border-bottom: 1px solid #d4af37; padding-bottom: 10px;">
🧙 预言师的解释
</h3>
</div>
""", unsafe_allow_html=True)
                with st.expander("查看预言解释", expanded=True):
                    st.markdown("""
<div style="background-color: #16213e; border: 1px solid #d4af37; border-radius: 8px; padding: 15px; color: #e0e0e0;">
<p><strong style="color: #d4af37;">📝 内容摘要:</strong> {}</p>
<p><strong style="color: #d4af37;">⚖️ 预言原因:</strong> {}</p>
<p><strong style="color: #d4af37;">🔮 可信度:</strong> {} - {}</p>
</div>
""".format(
                        _as_html_text(explanation['content_summary']),
                        _as_html_text(explanation['classification_reason']),
                        _as_html_text(explanation['confidence_level']),
                        _as_html_text(explanation['confidence_explanation']),
                    ), unsafe_allow_html=True)

                # Suggested actions, rendered as an ordered list.
                st.markdown("""
<div style="margin-top: 20px; background-color: #0f3460; border: 2px solid #d4af37; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; text-align: center; border-bottom: 1px solid #d4af37; padding-bottom: 10px;">
💡 行动建议
</h3>
</div>
""", unsafe_allow_html=True)
                suggestion_items = "".join(
                    f"<li style='margin-bottom: 10px;'>{_as_html_text(suggestion)}</li>"
                    for suggestion in explanation['suggestions']
                )
                suggestion_html = """
<div style="background-color: #16213e; border: 1px solid #d4af37; border-radius: 8px; padding: 15px;">
<ol style="color: #e0e0e0; list-style-type: decimal; padding-left: 20px;">
""" + suggestion_items + """
</ol>
</div>
"""
                st.markdown(suggestion_html, unsafe_allow_html=True)
            except Exception as e:
                # UI boundary: show the failure instead of crashing the page.
                st.markdown("""
<div style="background-color: #0f3460; border: 2px solid #c8102e; border-radius: 8px; padding: 15px; color: #c8102e;">
❌ 预言失败!
<p style="color: #e0e0e0; margin-top: 10px;">
预言师遇到了问题:{}
</p>
</div>
""".format(_as_html_text(e)), unsafe_allow_html=True)

    # --- Batch result ---------------------------------------------------------
    # NOTE(review): Streamlit reruns the whole script on every widget
    # interaction. While a file stays uploaded, the entire batch is therefore
    # re-classified on each rerun (e.g. when "开始检查" is clicked). Consider
    # caching results in st.session_state keyed by the file contents.
    if uploaded_file is not None:
        with st.spinner("🧙‍♂️ 预言师正在批量解析信件..."):
            try:
                df = pd.read_csv(uploaded_file)
                if "text" not in df.columns:
                    st.markdown("""
<div style="background-color: #0f3460; border: 2px solid #c8102e; border-radius: 8px; padding: 15px; color: #c8102e;">
❌ 预言失败!
<p style="color: #e0e0e0; margin-top: 10px;">
信件文件必须包含'text'
</p>
</div>
""", unsafe_allow_html=True)
                else:
                    # Missing or blank cells are skipped rather than sent to the
                    # agent, so one empty row cannot abort the whole batch.
                    texts = df["text"].dropna().astype(str)
                    texts = texts[texts.str.strip() != ""]

                    # Cap the number of messages processed per upload.
                    max_rows = 100
                    if len(texts) > max_rows:
                        # f-string: the counts must be interpolated into the text.
                        st.markdown(f"""
<div style="background-color: #0f3460; border: 2px solid #d4af37; border-radius: 8px; padding: 15px; color: #d4af37;">
⚠️ 警告
<p style="color: #e0e0e0; margin-top: 10px;">
信件文件包含 {len(texts)} 封信件,预言师将只解析前 {max_rows} 封
</p>
</div>
""", unsafe_allow_html=True)
                        texts = texts.head(max_rows)

                    if texts.empty:
                        st.markdown("""
<div style="background-color: #0f3460; border: 2px solid #d4af37; border-radius: 8px; padding: 15px; color: #d4af37;">
⚠️ 警告
<p style="color: #e0e0e0; margin-top: 10px;">
信件文件中没有可解析的信件内容
</p>
</div>
""", unsafe_allow_html=True)
                    else:
                        # Classify each message and flatten the fields the
                        # table needs.
                        results = []
                        for text in texts:
                            result = agent.classify_and_explain(text)
                            results.append({
                                "text": text,
                                "label": result['classification']['label'],
                                "spam_probability": result['classification']['probability']['spam'],
                                "ham_probability": result['classification']['probability']['ham'],
                                "content_summary": result['explanation']['content_summary'],
                                "classification_reason": result['explanation']['classification_reason'],
                            })
                        results_df = pd.DataFrame(results)

                        # Label distribution chart. Unknown labels keep their
                        # raw value instead of becoming NaN.
                        st.markdown("""
<div style="margin-top: 20px; background-color: #0f3460; border: 2px solid #d4af37; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; text-align: center; border-bottom: 1px solid #d4af37; padding-bottom: 10px;">
📊 预言统计
</h3>
</div>
""", unsafe_allow_html=True)
                        batch_labels = {"spam": "垃圾信件", "ham": "正常信件"}
                        label_counts = results_df["label"].value_counts()
                        label_counts.index = [
                            batch_labels.get(label, label) for label in label_counts.index
                        ]
                        st.bar_chart(label_counts)

                        # Per-message results table.
                        st.markdown("""
<div style="margin-top: 20px; background-color: #0f3460; border: 2px solid #d4af37; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; text-align: center; border-bottom: 1px solid #d4af37; padding-bottom: 10px;">
📑 预言结果
</h3>
</div>
""", unsafe_allow_html=True)
                        st.dataframe(
                            results_df,
                            use_container_width=True,
                            column_config={
                                "text": st.column_config.TextColumn("信件内容", width="medium"),
                                "label": st.column_config.TextColumn("预言标签"),
                                "spam_probability": st.column_config.ProgressColumn(
                                    "垃圾信件概率",
                                    format="%.2f",
                                    min_value=0.0,
                                    max_value=1.0,
                                ),
                                "ham_probability": st.column_config.ProgressColumn(
                                    "正常信件概率",
                                    format="%.2f",
                                    min_value=0.0,
                                    max_value=1.0,
                                ),
                                "content_summary": st.column_config.TextColumn("内容摘要", width="medium"),
                                "classification_reason": st.column_config.TextColumn("预言原因", width="medium"),
                            },
                        )

                        # CSV download of the results table.
                        st.markdown("""
<div style="margin-top: 20px; background-color: #0f3460; border: 2px solid #d4af37; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; text-align: center; border-bottom: 1px solid #d4af37; padding-bottom: 10px;">
💾 保存预言
</h3>
</div>
""", unsafe_allow_html=True)
                        csv = results_df.to_csv(index=False).encode('utf-8')
                        st.download_button(
                            label="📄 下载预言结果",
                            data=csv,
                            file_name="spam_classification_results.csv",
                            mime="text/csv",
                            use_container_width=True,
                        )
            except Exception as e:
                # UI boundary: report batch failures without crashing the page.
                st.markdown("""
<div style="background-color: #0f3460; border: 2px solid #c8102e; border-radius: 8px; padding: 15px; color: #c8102e;">
❌ 预言失败!
<p style="color: #e0e0e0; margin-top: 10px;">
预言师遇到了问题:{}
</p>
</div>
""".format(_as_html_text(e)), unsafe_allow_html=True)
# Page footer: a separator rule, then a centred credits card rendered as raw
# HTML.
_FOOTER_HTML = """
<div style="text-align: center; padding: 20px; border-top: 2px solid #d4af37; margin-top: 40px;">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; margin-bottom: 10px;">
🏰 中世纪垃圾短信防御城堡
</h3>
<p style="color: #e0e0e0; font-family: 'Georgia', serif; font-size: 14px;">
© 2026 由骑士团建造 | 基于传统魔法 + LLM 预言 + Agent 使者
</p>
<p style="color: #d4af37; font-size: 12px; margin-top: 10px; font-style: italic;">
保护您的通信不受垃圾信息的侵袭!
</p>
</div>
"""
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)