删除 src/streamlit_tweet_app.py
This commit is contained in:
parent
220c166ff7
commit
69dbd80152
@ -1,350 +0,0 @@
|
|||||||
"""Streamlit 演示应用 - 推文情感分析
|
|
||||||
|
|
||||||
航空推文情感分析 AI 助手 - 支持情感分类、解释和处置方案生成。
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import streamlit as st
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from src.tweet_agent import TweetSentimentAgent, analyze_tweet
|
|
||||||
|
|
||||||
# Load environment variables (python-dotenv) before anything else runs.
load_dotenv()

# Page-level settings for the Streamlit app.
st.set_page_config(page_title="航空推文情感分析", page_icon="✈️", layout="wide")

# Sidebar: static description of the model shown to the user.
sidebar = st.sidebar
sidebar.header("🔧 配置")
sidebar.markdown("### 模型信息")
sidebar.info(
    """
    **模型**: VotingClassifier (5个基学习器)
    - Logistic Regression
    - Multinomial Naive Bayes
    - Random Forest
    - LightGBM
    - XGBoost

    **性能**: Macro-F1 = 0.7533 ✅
    """
)

sidebar.markdown("---")

# Sidebar: which of the three views the rest of the script renders.
view_options = ["📝 单条分析", "📊 批量分析", "📈 数据概览"]
mode = sidebar.radio("功能选择", view_options)

# Session state: build the agent once per session, behind a spinner.
if "agent" not in st.session_state:
    with st.spinner("🔄 加载模型..."):
        st.session_state["agent"] = TweetSentimentAgent()

# Session state: results of the most recent batch run, kept across reruns.
if "batch_results" not in st.session_state:
    st.session_state["batch_results"] = []
|
|
||||||
|
|
||||||
|
|
||||||
# --- Helper Functions ---
|
|
||||||
|
|
||||||
|
|
||||||
def get_sentiment_emoji(sentiment: str) -> str:
    """Return the emoji that represents a sentiment label.

    Args:
        sentiment: Sentiment label. ``"negative"``, ``"neutral"`` and
            ``"positive"`` are recognised, ignoring letter case and
            surrounding whitespace.

    Returns:
        The matching emoji, or ``"❓"`` for any other value, including
        non-string input.
    """
    emoji_map = {
        "negative": "😠",
        "neutral": "😐",
        "positive": "😊",
    }
    unknown = "❓"
    # Non-strings (None, NaN, unhashable objects) can never be a known label.
    if not isinstance(sentiment, str):
        return unknown
    return emoji_map.get(sentiment.strip().lower(), unknown)
|
|
||||||
|
|
||||||
|
|
||||||
def get_sentiment_color(sentiment: str) -> str:
    """Return the hex colour used to display a sentiment label.

    Args:
        sentiment: Sentiment label. ``"negative"``, ``"neutral"`` and
            ``"positive"`` are recognised, ignoring letter case and
            surrounding whitespace.

    Returns:
        A ``#rrggbb`` colour string; ``"#e0e0e0"`` for any other value,
        including non-string input.
    """
    color_map = {
        "negative": "#ff6b6b",
        "neutral": "#ffd93d",
        "positive": "#6bcb77",
    }
    fallback = "#e0e0e0"
    # Non-strings (None, NaN, unhashable objects) can never be a known label.
    if not isinstance(sentiment, str):
        return fallback
    return color_map.get(sentiment.strip().lower(), fallback)
|
|
||||||
|
|
||||||
|
|
||||||
def get_priority_color(priority: str) -> str:
    """Return the hex colour used to display a handling priority.

    Args:
        priority: Priority label. ``"high"``, ``"medium"`` and ``"low"``
            are recognised, ignoring letter case and surrounding
            whitespace.

    Returns:
        A ``#rrggbb`` colour string; ``"#e0e0e0"`` for any other value,
        including non-string input.
    """
    color_map = {
        "high": "#ff4757",
        "medium": "#ffa502",
        "low": "#2ed573",
    }
    fallback = "#e0e0e0"
    # Non-strings (None, NaN, unhashable objects) can never be a known label.
    if not isinstance(priority, str):
        return fallback
    return color_map.get(priority.strip().lower(), fallback)
|
|
||||||
|
|
||||||
|
|
||||||
# --- Main Views ---
|
|
||||||
|
|
||||||
# Airlines offered in the single-tweet selector and recognised by the batch heuristic.
KNOWN_AIRLINES = ["United", "US Airways", "American", "Southwest", "Delta", "Virgin America"]
# Airline assumed when a tweet or CSV row does not identify one.
DEFAULT_AIRLINE = "United"


def _guess_airline(text: str) -> str:
    """Return the first known airline mentioned in *text*, else the default."""
    lowered = text.lower()
    for name in KNOWN_AIRLINES:
        needle = name.lower()
        # Handle-style mentions carry no spaces (e.g. "@USAirways"), so the
        # space-free spelling is tried as well as the display name.
        if needle in lowered or needle.replace(" ", "") in lowered:
            return name
    return DEFAULT_AIRLINE


if mode == "📝 单条分析":
    import html

    st.title("✈️ 航空推文情感分析")
    st.markdown("输入推文文本,获取 AI 驱动的情感分析、解释和处置方案。")

    # Input form
    with st.form("tweet_analysis_form"):
        col1, col2 = st.columns([3, 1])

        with col1:
            tweet_text = st.text_area(
                "推文内容",
                placeholder="@United This is the worst airline ever! My flight was delayed for 5 hours...",
                height=100,
            )

        with col2:
            airline = st.selectbox(
                "航空公司",
                KNOWN_AIRLINES,
            )

        submitted = st.form_submit_button("🔍 分析", type="primary")

    if submitted and tweet_text:
        with st.spinner("🤖 AI 正在分析..."):
            # Rendering stays inside the try so that a malformed result is
            # reported through the same error banner as a failed analysis.
            try:
                result = analyze_tweet(tweet_text, airline)

                # Display results
                st.divider()

                # Header with sentiment
                sentiment = result.classification.sentiment
                sentiment_emoji = get_sentiment_emoji(sentiment)
                sentiment_color = get_sentiment_color(sentiment)

                # Values interpolated into raw HTML are escaped: with
                # unsafe_allow_html=True any markup in them would be rendered.
                st.markdown(
                    f"""
                    <div style="background-color: {sentiment_color}; padding: 20px; border-radius: 10px; text-align: center;">
                    <h1 style="color: white; margin: 0;">{sentiment_emoji} {html.escape(sentiment.upper())}</h1>
                    <p style="color: white; margin: 10px 0 0 0;">置信度: {result.classification.confidence:.1%}</p>
                    </div>
                    """,
                    unsafe_allow_html=True,
                )

                st.divider()

                # Original tweet
                st.subheader("📝 原始推文")
                st.info(f"**航空公司**: {result.airline}\n\n**内容**: {result.tweet_text}")

                # Explanation
                st.subheader("🔍 情感解释")
                st.markdown("**关键因素:**")
                for factor in result.explanation.key_factors:
                    st.write(f"- {factor}")

                st.markdown("**推理过程:**")
                st.write(result.explanation.reasoning)

                # Disposal plan
                st.subheader("📋 处置方案")

                priority = result.disposal_plan.priority
                priority_color = get_priority_color(priority)
                # f"{...}" first, so the escaped text is exactly what the
                # template would otherwise have shown.
                action_type_html = html.escape(f"{result.disposal_plan.action_type}")
                st.markdown(
                    f"""
                    <div style="background-color: {priority_color}; padding: 10px; border-radius: 5px; display: inline-block;">
                    <span style="color: white; font-weight: bold;">优先级: {html.escape(priority.upper())}</span>
                    </div>
                    <br><br>
                    **行动类型**: {action_type_html}
                    """,
                    unsafe_allow_html=True,
                )

                if result.disposal_plan.suggested_response:
                    st.markdown("**建议回复:**")
                    st.success(result.disposal_plan.suggested_response)

                st.markdown("**后续行动:**")
                for action in result.disposal_plan.follow_up_actions:
                    st.write(f"- {action}")

            except Exception as e:
                st.error(f"分析失败: {e!s}")

elif mode == "📊 批量分析":
    st.title("📊 批量推文分析")
    st.markdown("上传 CSV 文件或输入多条推文,进行批量情感分析。")

    # Input method selection
    input_method = st.radio("输入方式", ["手动输入", "CSV 上传"], horizontal=True)

    if input_method == "手动输入":
        st.markdown("### 输入推文(每行一条)")
        tweets_input = st.text_area(
            "推文列表",
            placeholder="@United Flight delayed again!\n@Southwest Great service!\n@American Baggage policy?",
            height=200,
        )

        if st.button("🔍 批量分析", type="primary") and tweets_input:
            # One tweet per non-blank line.
            lines = [line.strip() for line in tweets_input.split("\n") if line.strip()]

            if lines:
                with st.spinner("🤖 AI 正在分析..."):
                    results = []
                    for line in lines:
                        # A failing tweet is reported and skipped; the rest
                        # of the batch still runs.
                        try:
                            result = analyze_tweet(line, _guess_airline(line))
                            results.append(result)
                        except Exception as e:
                            st.warning(f"分析失败: {line[:50]}... - {e}")

                if results:
                    st.session_state.batch_results = results
                    st.success(f"✅ 成功分析 {len(results)} 条推文")

    else:  # CSV upload
        st.markdown("### 上传 CSV 文件")
        st.info("CSV 文件应包含以下列: `text` (推文内容), `airline` (航空公司)")

        uploaded_file = st.file_uploader("选择文件", type=["csv"])

        if uploaded_file and st.button("🔍 分析上传文件", type="primary"):
            import pandas as pd

            try:
                df = pd.read_csv(uploaded_file)

                if "text" not in df.columns:
                    st.error("CSV 文件必须包含 'text' 列")
                else:
                    with st.spinner("🤖 AI 正在分析..."):
                        results = []
                        for _, row in df.iterrows():
                            # Bound before the try so the failure message
                            # below can always reference it.
                            text = row["text"]
                            try:
                                airline = row.get("airline", DEFAULT_AIRLINE)
                                # An empty CSV cell is read as NaN; treat it
                                # like a missing airline column.
                                if pd.isna(airline):
                                    airline = DEFAULT_AIRLINE
                                result = analyze_tweet(text, airline)
                                results.append(result)
                            except Exception as e:
                                # str() because a non-text cell (NaN, number)
                                # cannot be sliced; slicing it here would
                                # raise and abort the whole batch.
                                st.warning(f"分析失败: {str(text)[:50]}... - {e}")

                    if results:
                        st.session_state.batch_results = results
                        st.success(f"✅ 成功分析 {len(results)} 条推文")

            except Exception as e:
                st.error(f"文件读取失败: {e!s}")

    # Display batch results
    if st.session_state.batch_results:
        st.divider()
        st.subheader(f"📊 分析结果 ({len(st.session_state.batch_results)} 条)")

        # Summary statistics
        sentiments = [r.classification.sentiment for r in st.session_state.batch_results]
        negative_count = sentiments.count("negative")
        neutral_count = sentiments.count("neutral")
        positive_count = sentiments.count("positive")

        col1, col2, col3 = st.columns(3)
        col1.metric("😠 负面", negative_count)
        col2.metric("😐 中性", neutral_count)
        col3.metric("😊 正面", positive_count)

        # Detailed results table
        st.markdown("### 详细结果")

        results_data = [
            {
                # Long tweets are cut to 50 characters plus an ellipsis.
                "推文": (r.tweet_text[:50] + "...") if len(r.tweet_text) > 50 else r.tweet_text,
                "航空公司": r.airline,
                "情感": f"{get_sentiment_emoji(r.classification.sentiment)} {r.classification.sentiment}",
                "置信度": f"{r.classification.confidence:.1%}",
                "优先级": r.disposal_plan.priority.upper(),
                "行动类型": r.disposal_plan.action_type,
            }
            for r in st.session_state.batch_results
        ]

        st.dataframe(results_data, use_container_width=True)

        # Clear button
        if st.button("🗑️ 清除结果"):
            st.session_state.batch_results = []
            st.rerun()

elif mode == "📈 数据概览":
    st.title("📈 数据集概览")
    st.markdown("查看训练数据集的统计信息。")

    # Imports live inside the try so a missing optional dependency shows up
    # as the error banner below instead of crashing the app.
    try:
        import plotly.express as px
        import polars as pl

        from src.tweet_data import load_cleaned_tweets, print_data_summary

        df = load_cleaned_tweets("data/Tweets_cleaned.csv")

        # Display summary
        st.subheader("📊 数据统计")
        # NOTE(review): print_data_summary is defined in src.tweet_data and is
        # not visible here. If it only prints to stdout, nothing is rendered
        # under this subheader - confirm.
        print_data_summary(df, "数据集统计")

        # Display sample data
        st.subheader("📝 样本数据")
        sample_df = df.head(10).to_pandas()
        st.dataframe(sample_df, use_container_width=True)

        # Sentiment distribution chart
        st.subheader("📈 情感分布")
        sentiment_counts = df.group_by("airline_sentiment").agg(
            pl.col("airline_sentiment").count().alias("count")
        ).sort("count", descending=True)

        sentiment_df = sentiment_counts.to_pandas()
        fig = px.pie(
            sentiment_df,
            values="count",
            names="airline_sentiment",
            title="情感分布",
            color_discrete_map={
                "negative": "#ff6b6b",
                "neutral": "#ffd93d",
                "positive": "#6bcb77",
            },
        )
        st.plotly_chart(fig, use_container_width=True)

        # Airline distribution chart
        st.subheader("✈️ 航空公司分布")
        airline_counts = df.group_by("airline").agg(
            pl.col("airline").count().alias("count")
        ).sort("count", descending=True)

        airline_df = airline_counts.to_pandas()
        fig = px.bar(
            airline_df,
            x="airline",
            y="count",
            title="各航空公司推文数量",
            color="count",
            color_continuous_scale="Blues",
        )
        st.plotly_chart(fig, use_container_width=True)

    except Exception as e:
        st.error(f"数据加载失败: {e!s}")
|
|
||||||
|
|
||||||
# Footer: a divider followed by a small, centred grey credit line.
st.divider()
footer_html = """
    <div style="text-align: center; color: gray; font-size: 12px;">
    航空推文情感分析 AI 助手 | 基于 VotingClassifier (LR + NB + RF + LightGBM + XGBoost)
    </div>
    """
# Raw HTML is needed for the inline styling; the string is a fixed literal.
st.markdown(footer_html, unsafe_allow_html=True)
|
|
||||||
Loading…
Reference in New Issue
Block a user