"""Streamlit front end for the spam SMS classification system.

The page offers three features:

* classify a single SMS typed into a text area and show the label,
  the class probabilities and the explanation returned by the agent;
* classify the ``text`` column of an uploaded CSV file (at most
  ``MAX_BATCH_ROWS`` messages) and offer the results as a CSV download;
* optionally retrain and save the model selected in the sidebar.
"""
import os
import sys

import pandas as pd
import streamlit as st

# Add the project root to the Python path so that the `src` package resolves.
# Streamlit re-executes this script on every interaction, so guard against
# appending the same entry again and again.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from src.agent import agent, get_agent  # noqa: E402
from src.data import load_data, preprocess_data, split_data  # noqa: E402
from src.models import train_model, save_model, load_model, compare_models  # noqa: E402

# Training data location, anchored to the project root instead of the current
# working directory so training works no matter where Streamlit was launched.
TRAINING_DATA_PATH = os.path.join(PROJECT_ROOT, "data", "spam.csv")

# Upper bound on the number of messages classified per uploaded file.
MAX_BATCH_ROWS = 100

# Session-state slot holding the most recent batch classification results.
BATCH_CACHE_KEY = "batch_classification_cache"


def _extract_texts(df):
    """Return the usable messages from the ``text`` column of *df* as strings.

    Missing cells and whitespace-only cells are dropped so that a single blank
    row cannot abort the whole batch.
    """
    return [
        str(value)
        for value in df["text"].dropna().tolist()
        if str(value).strip()
    ]


def _classify_batch(texts):
    """Classify every message in *texts* and return one DataFrame row per message."""
    rows = []
    for text in texts:
        result = agent.classify_and_explain(text)
        classification = result['classification']
        explanation = result['explanation']
        rows.append({
            "text": text,
            "label": classification['label'],
            "spam_probability": classification['probability']['spam'],
            "ham_probability": classification['probability']['ham'],
            "content_summary": explanation['content_summary'],
            "classification_reason": explanation['classification_reason'],
        })
    return pd.DataFrame(rows)


def _render_single_result(result):
    """Render the label, probabilities, raw details and explanation of one result."""
    classification = result['classification']
    explanation = result['explanation']

    # Classification label, styled according to the predicted class
    st.subheader("📋 分类标签")
    if classification['label'] == "spam":
        st.error("⚠️ 这是一条**垃圾短信**")
    else:
        st.success("✅ 这是一条**正常短信**")

    # Class probabilities
    st.subheader("📊 分类概率")
    prob_df = pd.DataFrame.from_dict(
        classification['probability'],
        orient='index',
        columns=['概率'],
    )
    st.bar_chart(prob_df)

    # Raw classification payload
    st.subheader("📝 详细结果")
    with st.expander("查看详细分类结果", expanded=True):
        st.json(classification, expanded=False)

    # Explanation and suggested actions
    # NOTE(review): the original indentation was lost; the suggestions list is
    # assumed to live inside the explanation expander - confirm.
    st.subheader("🤔 结果解释")
    with st.expander("查看分类解释", expanded=True):
        st.write(f"**内容摘要**:{explanation['content_summary']}")
        st.write(f"**分类原因**:{explanation['classification_reason']}")
        st.write(
            f"**可信度**:{explanation['confidence_level']} - "
            f"{explanation['confidence_explanation']}"
        )

        st.subheader("💡 行动建议")
        for idx, suggestion in enumerate(explanation['suggestions'], start=1):
            st.write(f"{idx}. {suggestion}")


def _render_batch_results(results_df):
    """Render the statistics chart, the results table and the download button."""
    # Label distribution
    st.subheader("📊 分类统计")
    label_counts = results_df["label"].value_counts()
    st.bar_chart(label_counts)

    # Results table
    st.subheader("📋 分类结果")
    st.dataframe(
        results_df,
        use_container_width=True,
        column_config={
            "text": st.column_config.TextColumn("短信内容", width="medium"),
            "label": st.column_config.TextColumn("分类标签"),
            "spam_probability": st.column_config.ProgressColumn(
                "垃圾短信概率",
                format="%.2f",
                min_value=0.0,
                max_value=1.0,
            ),
            "ham_probability": st.column_config.ProgressColumn(
                "正常短信概率",
                format="%.2f",
                min_value=0.0,
                max_value=1.0,
            ),
            "content_summary": st.column_config.TextColumn("内容摘要", width="medium"),
            "classification_reason": st.column_config.TextColumn("分类原因", width="medium"),
        },
    )

    # Download of the results as CSV
    st.subheader("💾 下载结果")
    csv = results_df.to_csv(index=False).encode('utf-8')
    st.download_button(
        label="下载分类结果 (CSV)",
        data=csv,
        file_name="spam_classification_results.csv",
        mime="text/csv",
        use_container_width=True,
    )


# Page configuration
st.set_page_config(
    page_title="垃圾短信分类系统",
    page_icon="📱",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Application title
st.title("📱 垃圾短信分类系统")
st.markdown("---")

# Sidebar
with st.sidebar:
    st.header("系统配置")

    # Model selection
    # NOTE(review): in this script the selection is only used by the retraining
    # section; classification goes through the shared `agent` - confirm that
    # this is intended.
    model_option = st.selectbox(
        "选择模型",
        options=["lightgbm", "logistic_regression"],
        index=0,
        help="选择用于分类的机器学习模型",
    )

    # Output language selection
    # NOTE(review): `lang_option` is not referenced anywhere else in this
    # script - confirm whether it should be passed to the agent.
    lang_option = st.selectbox(
        "输出语言",
        options=["中文", "英文"],
        index=0,
        help="选择分类结果和解释的输出语言",
    )

    # About the system
    st.markdown("---")
    st.header("关于系统")
    st.info(
        "这是一个基于传统机器学习 + LLM + Agent的垃圾短信分类系统。\n"
        "- 使用LightGBM和Logistic Regression进行分类\n"
        "- 利用DeepSeek LLM解释分类结果\n"
        "- 通过Agent实现工具调用和结果整合"
    )

# Main content area
col1, col2 = st.columns([1, 1], gap="large")

with col1:
    # SMS input
    st.header("输入短信")

    # Single message input
    sms_input = st.text_area(
        "请输入要分类的短信",
        height=200,
        placeholder="例如:WINNER!! As a valued network customer you have been selected to receivea £900 prize reward!",
    )

    # Classify button, disabled while the input is blank
    classify_button = st.button(
        "📊 开始分类",
        type="primary",
        use_container_width=True,
        disabled=sms_input.strip() == "",
    )

    # Batch upload
    st.markdown("---")
    st.header("批量分类")
    uploaded_file = st.file_uploader(
        "上传CSV文件(包含text列)",
        type=["csv"],
        help="上传包含短信文本的CSV文件,系统将自动分类",
    )

    # Model training (optional)
    with st.expander("🔧 模型训练", expanded=False):
        if st.button("重新训练模型"):
            with st.spinner("正在训练模型..."):
                try:
                    # Load and preprocess the data
                    df = load_data(TRAINING_DATA_PATH)
                    processed_df = preprocess_data(df)
                    train_df, _test_df = split_data(processed_df)

                    # Train and persist the model
                    model, _params = train_model(train_df, model_type=model_option)
                    save_model(model, model_option)

                    # Cached batch results may predate the new model, so
                    # discard them after a successful retraining.
                    st.session_state.pop(BATCH_CACHE_KEY, None)

                    st.success(f"{model_option} 模型训练完成!")
                except Exception as e:
                    # UI boundary: report any failure to the user
                    st.error(f"模型训练失败:{e}")

with col2:
    # Classification results
    st.header("分类结果")

    # Single message result
    if classify_button and sms_input.strip():
        with st.spinner("正在分类..."):
            try:
                # Classify and explain through the agent
                result = agent.classify_and_explain(sms_input)
                _render_single_result(result)
            except Exception as e:
                # UI boundary: report any failure to the user
                st.error(f"分类失败:{e}")

    # Batch classification results
    if uploaded_file is not None:
        with st.spinner("正在批量分类..."):
            try:
                # Read the uploaded file; rewind first so a rerun never reads
                # from an exhausted buffer.
                uploaded_file.seek(0)
                df = pd.read_csv(uploaded_file)

                if "text" not in df.columns:
                    st.error("CSV文件必须包含'text'列")
                else:
                    texts = _extract_texts(df)

                    if not texts:
                        # Nothing to classify: avoid building an empty result
                        # frame that has no "label" column.
                        st.warning("CSV文件的'text'列中没有可分类的短信")
                    else:
                        # Limit the number of processed messages
                        if len(texts) > MAX_BATCH_ROWS:
                            st.warning(
                                f"文件包含 {len(texts)} 条记录,仅处理前 {MAX_BATCH_ROWS} 条"
                            )
                            texts = texts[:MAX_BATCH_ROWS]

                        # Every widget interaction (including the download
                        # button) re-executes this script; reuse the previous
                        # results when the messages are unchanged instead of
                        # classifying the whole batch again.
                        cached = st.session_state.get(BATCH_CACHE_KEY)
                        if cached is not None and cached["texts"] == texts:
                            results_df = cached["results"]
                        else:
                            results_df = _classify_batch(texts)
                            st.session_state[BATCH_CACHE_KEY] = {
                                "texts": texts,
                                "results": results_df,
                            }

                        _render_batch_results(results_df)
            except Exception as e:
                # UI boundary: report any failure to the user
                st.error(f"批量分类失败:{e}")

# Footer
# NOTE(review): the footer string contains no HTML, so `unsafe_allow_html`
# currently has no visible effect; the original markup may have been lost.
st.markdown("---")
st.markdown(
    "\n© 2026 垃圾短信分类系统 | 基于传统机器学习 + LLM + Agent\n",
    unsafe_allow_html=True,
)