CourseDesign/src/streamlit_app.py

import streamlit as st
import asyncio
import os
import sys
# Ensure project root is in path
sys.path.append(os.getcwd())
from src.agent_app import agent, counselor_agent, StudyGuidance
from pydantic_ai.messages import (
    ModelMessage, ModelRequest, ModelResponse, UserPromptPart, TextPart,
    TextPartDelta, ToolCallPart, ToolReturnPart
)
from pydantic_ai import (
    AgentStreamEvent, PartDeltaEvent, FunctionToolCallEvent, FunctionToolResultEvent
)
from dotenv import load_dotenv

# Load env variables
load_dotenv()

st.set_page_config(
    page_title="学生成绩预测 AI 助手",
    page_icon="🎓",
    layout="wide"
)
# Sidebar Configuration
st.sidebar.header("🔧 配置")
api_key = st.sidebar.text_input("DeepSeek API Key", type="password", value=os.getenv("DEEPSEEK_API_KEY", ""))
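# Export the key to the environment so the DeepSeek-backed agents imported from src.agent_app can read it.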
if api_key:
    os.environ["DEEPSEEK_API_KEY"] = api_key
st.sidebar.markdown("---")
# Mode Selection
mode = st.sidebar.radio("功能选择", ["📊 成绩预测", "💬 心理咨询"])
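# The selected mode decides which of the two views below is rendered.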
# --- Helper Functions ---
async def run_analysis(query):
    try:
        if not os.getenv("DEEPSEEK_API_KEY"):
            st.error("请在侧边栏提供 DeepSeek API Key。")
            return None
        with st.spinner("🤖 Agent 正在思考... (调用 DeepSeek + 随机森林模型)"):
            result = await agent.run(query)
        return result.output
    except Exception as e:
        st.error(f"分析失败: {str(e)}")
        return None

async def run_counselor_stream(query, history, placeholder):
    """
    Manually stream the response to a placeholder, handling tool events for visibility.
    """
    try:
        if not os.getenv("DEEPSEEK_API_KEY"):
            placeholder.error("❌ 错误: 请在侧边栏提供 DeepSeek API Key。")
            return None
        full_response = ""
        # Status container for tool calls
        status_placeholder = st.empty()
        # Call the counselor agent with streaming via run_stream_events, which yields
        # text deltas and tool events as they arrive.
        async for event in counselor_agent.run_stream_events(query, message_history=history):
            # Text delta (wrapped in a PartDeltaEvent)
            if isinstance(event, PartDeltaEvent) and isinstance(event.delta, TextPartDelta):
                full_response += event.delta.content_delta
                placeholder.markdown(full_response)
            # Tool call started (FunctionToolCallEvent.part is a ToolCallPart)
            elif isinstance(event, FunctionToolCallEvent):
                status_placeholder.info(f"🛠️ 咨询师正在使用工具: `{event.part.tool_name}` ...")
            # Tool call finished
            elif isinstance(event, FunctionToolResultEvent):
                status_placeholder.empty()
        # Final render and cleanup once the stream is done
        placeholder.markdown(full_response)
        status_placeholder.empty()
        return full_response
    except Exception as e:
        placeholder.error(f"❌ 咨询失败: {str(e)}")
        return None

# --- Main Views ---
if mode == "📊 成绩预测":
    st.title("🎓 学生成绩预测助手")
    st.markdown("在下方输入学生详细信息,获取 AI 驱动的成绩分析。")
    with st.form("student_data_form"):
        col1, col2 = st.columns(2)
        with col1:
            study_hours = st.slider("每周学习时长 (小时)", 0.0, 20.0, 10.0, 0.5)
            sleep_hours = st.slider("日均睡眠时长 (小时)", 0.0, 12.0, 7.0, 0.5)
        with col2:
            attendance_rate = st.slider("出勤率", 0.0, 1.0, 0.9, 0.05)
            stress_level = st.select_slider("压力等级 (1=低, 5=高)", options=[1, 2, 3, 4, 5], value=3)
        study_type = st.radio("主要学习方式", ["Self", "Group", "Online"], horizontal=True)
        submitted = st.form_submit_button("🚀 分析通过率")
    if submitted:
        user_query = (
            f"我是一名学生,情况如下:"
            f"每周学习时间: {study_hours} 小时;"
            f"平均睡眠时间: {sleep_hours} 小时;"
            f"出勤率: {attendance_rate:.2f};"
            f"压力等级: {stress_level} (1-5);"
            f"主要学习方式: {study_type}。"
            f"请调用 `predict_student` 预测我的通过率,并调用 `explain_model` 分析关键因素,最后给出针对性的建议。"
        )
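        # Streamlit scripts run synchronously, so drive the async agent call on a dedicated event loop.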
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        guidance = loop.run_until_complete(run_analysis(user_query))
        if guidance:
            st.divider()
            st.subheader("📊 分析结果")
            m1, m2, m3 = st.columns(3)
            m1.metric("预测通过率", f"{guidance.pass_probability:.1%}")
            m2.metric("风险评估", "高风险" if guidance.pass_probability < 0.6 else "低风险",
                      delta="-高风险" if guidance.pass_probability < 0.6 else "+安全")
            st.info(f"**风险评估:** {guidance.risk_assessment}")
            st.write(f"**关键因素:** {guidance.key_drivers}")
            st.subheader("✅ 行动计划")
            actions = [{"优先级": item.priority, "建议行动": item.action} for item in guidance.action_plan]
            st.table(actions)

elif mode == "💬 心理咨询":
    st.title("🧩 AI 心理咨询室")
    st.markdown("这里是安全且私密的空间。有些压力如果你愿意说,我愿意听。")
    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    # React to user input
    if prompt := st.chat_input("想聊聊什么?"):
        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)
        # Add user message to history
        st.session_state.messages.append({"role": "user", "content": prompt})
        # Convert the Streamlit chat history into pydantic-ai ModelMessages.
        # The last message is excluded because the new prompt is passed to the agent directly.
        api_history = []
        for msg in st.session_state.messages[:-1]:
            if msg["role"] == "user":
                api_history.append(ModelRequest(parts=[UserPromptPart(content=msg["content"])]))
            elif msg["role"] == "assistant":
                api_history.append(ModelResponse(parts=[TextPart(content=msg["content"])]))
        # Generate response
        with st.chat_message("assistant"):
            placeholder = st.empty()
            with st.spinner("咨询师正在倾听..."):
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                # Run the manual streaming function
                response_text = loop.run_until_complete(run_counselor_stream(prompt, api_history, placeholder))
        if response_text:
            st.session_state.messages.append({"role": "assistant", "content": response_text})