feat(agent): 添加LLM解释功能并本地化界面
添加使用DeepSeek API生成自然语言解释的功能;将Streamlit界面从英文翻译为中文;更新依赖项,添加requests库;移除.gitignore中不必要的数据文件排除规则。
This commit is contained in:
parent
990323d5d6
commit
4553a1794d
3
.gitignore
vendored
3
.gitignore
vendored
@ -17,6 +17,3 @@ wheels/
|
|||||||
# Artifacts and models
|
# Artifacts and models
|
||||||
artifacts/
|
artifacts/
|
||||||
*.joblib
|
*.joblib
|
||||||
|
|
||||||
# Large data files
|
|
||||||
data/*.csv
|
|
||||||
|
|||||||
25001
data/Cleaned_Customer_Sentiment.csv
Normal file
25001
data/Cleaned_Customer_Sentiment.csv
Normal file
File diff suppressed because it is too large
Load Diff
25001
data/Customer_Sentiment.csv
Normal file
25001
data/Customer_Sentiment.csv
Normal file
File diff suppressed because it is too large
Load Diff
@ -15,4 +15,8 @@ dependencies = [
|
|||||||
"streamlit>=1.52.2",
|
"streamlit>=1.52.2",
|
||||||
"pydantic>=2.9.2",
|
"pydantic>=2.9.2",
|
||||||
"python-dotenv>=1.0.0",
|
"python-dotenv>=1.0.0",
|
||||||
|
"requests>=2.32.0",
|
||||||
]
|
]
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "tencent"
|
||||||
|
url = "https://mirrors.cloud.tencent.com/pypi/simple/"
|
||||||
76
src/agent.py
76
src/agent.py
@ -2,6 +2,7 @@ import os
|
|||||||
import joblib
|
import joblib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import requests
|
||||||
from typing import Literal, Annotated
|
from typing import Literal, Annotated
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -75,3 +76,78 @@ def explain_features(features: CustomerFeatures | dict) -> list[str]:
|
|||||||
|
|
||||||
def explain_features_model(features: CustomerFeatures | dict) -> ExplanationOutput:
|
def explain_features_model(features: CustomerFeatures | dict) -> ExplanationOutput:
|
||||||
return ExplanationOutput(factors=explain_features(features))
|
return ExplanationOutput(factors=explain_features(features))
|
||||||
|
|
||||||
|
# Chinese descriptions for the encoded feature names that are looked up in
# the strings returned by explain_features().
_FEATURE_DESCRIPTIONS = {
    'cat__gender_male': '男性',
    'cat__gender_female': '女性',
    'cat__gender_other': '其他性别',
    'cat__age_group_18-25': '18-25岁年龄段',
    'cat__age_group_26-35': '26-35岁年龄段',
    'cat__age_group_36-45': '36-45岁年龄段',
    'cat__age_group_46-60': '46-60岁年龄段',
    'cat__age_group_60+': '60岁以上年龄段',
    'cat__region_north': '北部地区',
    'cat__region_south': '南部地区',
    'cat__region_east': '东部地区',
    'cat__region_west': '西部地区',
    'cat__region_central': '中部地区',
    'cat__purchase_channel_online': '线上购买渠道',
    'cat__purchase_channel_offline': '线下购买渠道',
    'cat__issue_resolved_True': '问题已解决',
    'cat__issue_resolved_False': '问题未解决',
    'cat__complaint_registered_True': '已注册投诉',
    'cat__complaint_registered_False': '未注册投诉',
    'num__response_time_hours': '响应时间(小时)',
}

# English fragments searched for in each explanation and their Chinese wording.
_PHRASE_TRANSLATIONS = (
    ('increase negative risk', '增加了负面情绪风险'),
    ('decrease negative risk', '降低了负面情绪风险'),
    ('weight=', '权重为'),
)

_DEEPSEEK_CHAT_URL = "https://api.deepseek.com/v1/chat/completions"

# Upper bound (seconds) for the HTTP call; without it requests waits forever.
_LLM_TIMEOUT_SECONDS = 30


def _humanize_explanation(explanation: str) -> str:
    """Return *explanation* with feature names and stock phrases in Chinese."""
    # Only the first mapped feature name found is substituted, as before.
    for feature, description in _FEATURE_DESCRIPTIONS.items():
        if feature in explanation:
            explanation = explanation.replace(feature, description)
            break
    for phrase, translation in _PHRASE_TRANSLATIONS:
        explanation = explanation.replace(phrase, translation)
    return explanation


def explain_features_with_llm(features: CustomerFeatures | dict, api_key: str) -> str:
    """Generate a natural-language (Chinese) explanation of the risk factors.

    The factor strings from ``explain_features`` are translated into
    Chinese wording and sent to the DeepSeek chat-completions endpoint,
    which is asked to rewrite them as a short paragraph.

    Args:
        features: Customer features, passed through to ``explain_features``.
        api_key: Bearer token used for the DeepSeek API request.

    Returns:
        The text produced by the LLM. If the request fails, times out, or
        the response does not contain a usable text answer, a plain
        newline-joined listing of the translated factors is returned
        instead, so callers always receive a string.
    """
    # Local import keeps this block self-contained.
    import logging

    _ensure_loaded()

    # Every explanation is kept: ones without a known feature name still get
    # their phrases translated instead of being dropped from the summary.
    human_explanations = [_humanize_explanation(exp) for exp in explain_features(features)]
    factors_text = "\n".join(human_explanations)
    fallback = f"客户负面情绪风险分析:{factors_text}"

    prompt = f"请将以下客户负面情绪风险因素分析结果转化为一段自然、流畅的中文解释,用于向客服人员展示:\n\n{factors_text}\n\n要求:\n1. 用简洁的语言说明主要风险因素\n2. 突出影响最大的几个因素\n3. 保持专业但易于理解\n4. 不要使用技术术语\n5. 总长度控制在100-200字之间"

    payload = {
        "model": "deepseek-chat",
        "messages": [
            {"role": "system", "content": "你是一位专业的客户分析专家,擅长将复杂的数据分析结果转化为通俗易懂的解释。"},
            {"role": "user", "content": prompt},
        ],
        "max_tokens": 200,
        "temperature": 0.7,
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    try:
        response = requests.post(
            _DEEPSEEK_CHAT_URL,
            headers=headers,
            json=payload,
            timeout=_LLM_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"]
    except (requests.RequestException, KeyError, IndexError, TypeError, ValueError) as exc:
        # Best-effort feature: network/HTTP errors and malformed responses
        # degrade to the plain listing, but are logged rather than hidden.
        logging.getLogger(__name__).warning(
            "LLM explanation request failed, using fallback text: %s", exc
        )
        return fallback

    # Guard against a null/empty answer so the declared str return type holds.
    if not isinstance(content, str) or not content.strip():
        return fallback
    return content
|
||||||
|
|||||||
@ -2,10 +2,10 @@ import streamlit as st
|
|||||||
import polars as pl
|
import polars as pl
|
||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from agent import CustomerFeatures, predict_risk, explain_features
|
from agent import CustomerFeatures, predict_risk, explain_features, explain_features_with_llm
|
||||||
import altair as alt
|
import altair as alt
|
||||||
|
|
||||||
st.set_page_config(page_title="Customer Sentiment Analysis", layout="wide")
|
st.set_page_config(page_title="客户情感分析", layout="wide")
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
st.markdown(
|
st.markdown(
|
||||||
@ -116,10 +116,10 @@ def map_risk_to_advice(risk: float):
|
|||||||
return level, decision, actions
|
return level, decision, actions
|
||||||
|
|
||||||
|
|
||||||
st.title("📊 Customer Sentiment Analysis Dashboard")
|
st.title("📊 客户情感分析仪表板")
|
||||||
|
|
||||||
st.image(
|
st.image(
|
||||||
"https://via.placeholder.com/1200x200/fffde7/e65100?text=Customer+Sentiment+Prediction+System",
|
"https://via.placeholder.com/1200x200/fffde7/e65100?text=客户情感预测系统",
|
||||||
use_container_width=True,
|
use_container_width=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -128,7 +128,7 @@ st.image(
|
|||||||
def load_data():
|
def load_data():
|
||||||
file_path = os.path.join("data", "Cleaned_Customer_Sentiment.csv")
|
file_path = os.path.join("data", "Cleaned_Customer_Sentiment.csv")
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
st.error(f"Data file not found at {file_path}. Please run the data processing script first.")
|
st.error(f"在{file_path}找不到数据文件。请先运行数据处理脚本。")
|
||||||
return None
|
return None
|
||||||
return pl.read_csv(file_path)
|
return pl.read_csv(file_path)
|
||||||
|
|
||||||
@ -136,11 +136,11 @@ def load_data():
|
|||||||
df = load_data()
|
df = load_data()
|
||||||
|
|
||||||
if df is not None:
|
if df is not None:
|
||||||
st.sidebar.header("Filters")
|
st.sidebar.header("筛选条件")
|
||||||
sentiments = df["sentiment"].unique().to_list()
|
sentiments = df["sentiment"].unique().to_list()
|
||||||
selected_sentiment = st.sidebar.multiselect("Select Sentiment", sentiments, default=sentiments)
|
selected_sentiment = st.sidebar.multiselect("选择情感", sentiments, default=sentiments)
|
||||||
regions = df["region"].unique().to_list()
|
regions = df["region"].unique().to_list()
|
||||||
selected_region = st.sidebar.multiselect("Select Region", regions, default=regions)
|
selected_region = st.sidebar.multiselect("选择地区", regions, default=regions)
|
||||||
|
|
||||||
filtered_df = df.filter(
|
filtered_df = df.filter(
|
||||||
(pl.col("sentiment").is_in(selected_sentiment)) &
|
(pl.col("sentiment").is_in(selected_sentiment)) &
|
||||||
@ -149,19 +149,19 @@ if df is not None:
|
|||||||
|
|
||||||
st.markdown('<div id="overview"></div>', unsafe_allow_html=True)
|
st.markdown('<div id="overview"></div>', unsafe_allow_html=True)
|
||||||
col1, col2, col3 = st.columns(3)
|
col1, col2, col3 = st.columns(3)
|
||||||
col1.metric("Total Reviews", filtered_df.height)
|
col1.metric("总评论数", filtered_df.height)
|
||||||
col2.metric("Positive Reviews", filtered_df.filter(pl.col("sentiment") == "positive").height)
|
col2.metric("正面评论数", filtered_df.filter(pl.col("sentiment") == "positive").height)
|
||||||
col3.metric("Avg Rating", f"{filtered_df['customer_rating'].mean():.2f}")
|
col3.metric("平均评分", f"{filtered_df['customer_rating'].mean():.2f}")
|
||||||
|
|
||||||
st.markdown('<div id="stats"></div>', unsafe_allow_html=True)
|
st.markdown('<div id="stats"></div>', unsafe_allow_html=True)
|
||||||
st.subheader("Sentiment Distribution")
|
st.subheader("情感分布")
|
||||||
sentiment_counts = filtered_df["sentiment"].value_counts()
|
sentiment_counts = filtered_df["sentiment"].value_counts()
|
||||||
sentiment_chart = (
|
sentiment_chart = (
|
||||||
alt.Chart(sentiment_counts.to_pandas())
|
alt.Chart(sentiment_counts.to_pandas())
|
||||||
.mark_bar(cornerRadiusTopLeft=6, cornerRadiusTopRight=6)
|
.mark_bar(cornerRadiusTopLeft=6, cornerRadiusTopRight=6)
|
||||||
.encode(
|
.encode(
|
||||||
x=alt.X("sentiment:N", title="Sentiment"),
|
x=alt.X("sentiment:N", title="情感"),
|
||||||
y=alt.Y("count:Q", title="Count"),
|
y=alt.Y("count:Q", title="数量"),
|
||||||
color=alt.Color(
|
color=alt.Color(
|
||||||
"sentiment:N",
|
"sentiment:N",
|
||||||
scale=alt.Scale(
|
scale=alt.Scale(
|
||||||
@ -175,92 +175,92 @@ if df is not None:
|
|||||||
)
|
)
|
||||||
st.altair_chart(sentiment_chart, use_container_width=True)
|
st.altair_chart(sentiment_chart, use_container_width=True)
|
||||||
|
|
||||||
st.subheader("Data Preview")
|
st.subheader("数据预览")
|
||||||
st.dataframe(filtered_df.to_pandas())
|
st.dataframe(filtered_df.to_pandas())
|
||||||
|
|
||||||
banner_col1, banner_col2 = st.columns([2, 1])
|
banner_col1, banner_col2 = st.columns([2, 1])
|
||||||
with banner_col1:
|
with banner_col1:
|
||||||
st.subheader("Configuration Check")
|
st.subheader("配置检查")
|
||||||
api_key = os.getenv("API_KEY")
|
api_key = os.getenv("API_KEY")
|
||||||
if api_key:
|
if api_key:
|
||||||
st.success(f"API Key loaded: {api_key[:4]}****")
|
st.success(f"API密钥已加载: {api_key[:4]}****")
|
||||||
else:
|
else:
|
||||||
st.warning("API Key not found in environment variables.")
|
st.warning("在环境变量中未找到API密钥。")
|
||||||
with banner_col2:
|
with banner_col2:
|
||||||
st.image(
|
st.image(
|
||||||
"https://via.placeholder.com/400x200/e8f5e9/1b5e20?text=Customer+Care",
|
"https://via.placeholder.com/400x200/e8f5e9/1b5e20?text=客户关怀",
|
||||||
use_container_width=True,
|
use_container_width=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
st.subheader("Cleaning Stats")
|
st.subheader("数据清洗统计")
|
||||||
nulls_df = filtered_df.select([pl.col(c).is_null().sum().alias(c) for c in filtered_df.columns]).melt()
|
nulls_df = filtered_df.select([pl.col(c).is_null().sum().alias(c) for c in filtered_df.columns]).melt()
|
||||||
nulls_pd = nulls_df.rename({"variable": "feature", "value": "nulls"}).to_pandas()
|
nulls_pd = nulls_df.rename({"variable": "feature", "value": "nulls"}).to_pandas()
|
||||||
c1, c2 = st.columns(2)
|
c1, c2 = st.columns(2)
|
||||||
c1.metric("Issue Resolved Rate", f"{float(filtered_df['issue_resolved'].cast(pl.Boolean).mean())*100:.2f}%")
|
c1.metric("问题解决率", f"{float(filtered_df['issue_resolved'].cast(pl.Boolean).mean())*100:.2f}%")
|
||||||
c2.metric("Complaint Registered Rate", f"{float(filtered_df['complaint_registered'].cast(pl.Boolean).mean())*100:.2f}%")
|
c2.metric("投诉登记率", f"{float(filtered_df['complaint_registered'].cast(pl.Boolean).mean())*100:.2f}%")
|
||||||
st.subheader("Nulls Per Column")
|
st.subheader("每列空值统计")
|
||||||
nulls_chart = (
|
nulls_chart = (
|
||||||
alt.Chart(nulls_pd)
|
alt.Chart(nulls_pd)
|
||||||
.mark_bar(cornerRadiusTopLeft=4, cornerRadiusTopRight=4)
|
.mark_bar(cornerRadiusTopLeft=4, cornerRadiusTopRight=4)
|
||||||
.encode(
|
.encode(
|
||||||
x=alt.X("feature:N", sort="-y", title="Feature"),
|
x=alt.X("feature:N", sort="-y", title="特征"),
|
||||||
y=alt.Y("nulls:Q", title="Null Count"),
|
y=alt.Y("nulls:Q", title="空值数量"),
|
||||||
color=alt.value("#80cbc4"),
|
color=alt.value("#80cbc4"),
|
||||||
tooltip=["feature", "nulls"],
|
tooltip=["feature", "nulls"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
st.altair_chart(nulls_chart, use_container_width=True)
|
st.altair_chart(nulls_chart, use_container_width=True)
|
||||||
|
|
||||||
st.subheader("Gender Distribution")
|
st.subheader("性别分布")
|
||||||
gender_counts = filtered_df["gender"].value_counts()
|
gender_counts = filtered_df["gender"].value_counts()
|
||||||
gender_chart = (
|
gender_chart = (
|
||||||
alt.Chart(gender_counts.to_pandas())
|
alt.Chart(gender_counts.to_pandas())
|
||||||
.mark_bar(cornerRadiusTopLeft=6, cornerRadiusTopRight=6)
|
.mark_bar(cornerRadiusTopLeft=6, cornerRadiusTopRight=6)
|
||||||
.encode(
|
.encode(
|
||||||
x=alt.X("gender:N", title="Gender"),
|
x=alt.X("gender:N", title="性别"),
|
||||||
y=alt.Y("count:Q", title="Count"),
|
y=alt.Y("count:Q", title="数量"),
|
||||||
color=alt.value("#90caf9"),
|
color=alt.value("#90caf9"),
|
||||||
tooltip=["gender", "count"],
|
tooltip=["gender", "count"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
st.altair_chart(gender_chart, use_container_width=True)
|
st.altair_chart(gender_chart, use_container_width=True)
|
||||||
|
|
||||||
st.subheader("Avg Response Time by Sentiment")
|
st.subheader("不同情感的平均响应时间")
|
||||||
avg_resp = filtered_df.group_by("sentiment").agg(pl.col("response_time_hours").mean()).sort("response_time_hours", descending=True)
|
avg_resp = filtered_df.group_by("sentiment").agg(pl.col("response_time_hours").mean()).sort("response_time_hours", descending=True)
|
||||||
avg_resp_chart = (
|
avg_resp_chart = (
|
||||||
alt.Chart(avg_resp.to_pandas())
|
alt.Chart(avg_resp.to_pandas())
|
||||||
.mark_line(point=True)
|
.mark_line(point=True)
|
||||||
.encode(
|
.encode(
|
||||||
x=alt.X("sentiment:N", title="Sentiment"),
|
x=alt.X("sentiment:N", title="情感"),
|
||||||
y=alt.Y("response_time_hours:Q", title="Avg Response Time (hours)"),
|
y=alt.Y("response_time_hours:Q", title="平均响应时间(小时)"),
|
||||||
color=alt.value("#ffb74d"),
|
color=alt.value("#ffb74d"),
|
||||||
tooltip=["sentiment", "response_time_hours"],
|
tooltip=["sentiment", "response_time_hours"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
st.altair_chart(avg_resp_chart, use_container_width=True)
|
st.altair_chart(avg_resp_chart, use_container_width=True)
|
||||||
|
|
||||||
st.subheader("Top Platforms")
|
st.subheader("主要平台分布")
|
||||||
top_platforms = filtered_df["platform"].value_counts().sort("count", descending=True).head(10)
|
top_platforms = filtered_df["platform"].value_counts().sort("count", descending=True).head(10)
|
||||||
platform_chart = (
|
platform_chart = (
|
||||||
alt.Chart(top_platforms.to_pandas())
|
alt.Chart(top_platforms.to_pandas())
|
||||||
.mark_bar(cornerRadiusTopLeft=6, cornerRadiusTopRight=6)
|
.mark_bar(cornerRadiusTopLeft=6, cornerRadiusTopRight=6)
|
||||||
.encode(
|
.encode(
|
||||||
x=alt.X("platform:N", sort="-y", title="Platform"),
|
x=alt.X("platform:N", sort="-y", title="平台"),
|
||||||
y=alt.Y("count:Q", title="Count"),
|
y=alt.Y("count:Q", title="数量"),
|
||||||
color=alt.value("#ce93d8"),
|
color=alt.value("#ce93d8"),
|
||||||
tooltip=["platform", "count"],
|
tooltip=["platform", "count"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
st.altair_chart(platform_chart, use_container_width=True)
|
st.altair_chart(platform_chart, use_container_width=True)
|
||||||
|
|
||||||
st.subheader("Top Product Categories")
|
st.subheader("主要产品类别分布")
|
||||||
top_products = filtered_df["product_category"].value_counts().sort("count", descending=True).head(10)
|
top_products = filtered_df["product_category"].value_counts().sort("count", descending=True).head(10)
|
||||||
product_chart = (
|
product_chart = (
|
||||||
alt.Chart(top_products.to_pandas())
|
alt.Chart(top_products.to_pandas())
|
||||||
.mark_bar(cornerRadiusTopLeft=6, cornerRadiusTopRight=6)
|
.mark_bar(cornerRadiusTopLeft=6, cornerRadiusTopRight=6)
|
||||||
.encode(
|
.encode(
|
||||||
x=alt.X("product_category:N", sort="-y", title="Product Category"),
|
x=alt.X("product_category:N", sort="-y", title="产品类别"),
|
||||||
y=alt.Y("count:Q", title="Count"),
|
y=alt.Y("count:Q", title="数量"),
|
||||||
color=alt.value("#ffcc80"),
|
color=alt.value("#ffcc80"),
|
||||||
tooltip=["product_category", "count"],
|
tooltip=["product_category", "count"],
|
||||||
)
|
)
|
||||||
@ -316,6 +316,15 @@ if df is not None:
|
|||||||
st.progress(risk)
|
st.progress(risk)
|
||||||
|
|
||||||
st.markdown("#### 分析")
|
st.markdown("#### 分析")
|
||||||
|
# Get API key for LLM
|
||||||
|
api_key = os.getenv("API_KEY")
|
||||||
|
if api_key:
|
||||||
|
# Use LLM to generate natural language explanation
|
||||||
|
llm_explanation = explain_features_with_llm(customer, api_key)
|
||||||
|
st.write(llm_explanation)
|
||||||
|
else:
|
||||||
|
# Fallback to original explanation if no API key
|
||||||
|
st.warning("未找到API密钥,使用原始分析结果")
|
||||||
for item in explanations:
|
for item in explanations:
|
||||||
st.write(f"- {item}")
|
st.write(f"- {item}")
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user