2026-01-09 14:30:23 +08:00
|
|
|
|
"""Streamlit 演示应用
|
|
|
|
|
|
|
|
|
|
|
|
学生成绩预测 AI 助手 - 支持成绩预测分析和心理咨询对话。
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
2026-01-01 11:19:17 +08:00
|
|
|
|
import asyncio
|
|
|
|
|
|
import os
|
|
|
|
|
|
import sys
|
|
|
|
|
|
|
2026-01-09 14:30:23 +08:00
|
|
|
|
import streamlit as st
|
|
|
|
|
|
|
2026-01-01 11:19:17 +08:00
|
|
|
|
# Ensure project root is in path
|
|
|
|
|
|
sys.path.append(os.getcwd())
|
|
|
|
|
|
|
2026-01-09 14:30:23 +08:00
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
|
|
from pydantic_ai import FunctionToolCallEvent, FunctionToolResultEvent, PartDeltaEvent
|
2026-01-01 11:19:17 +08:00
|
|
|
|
from pydantic_ai.messages import (
|
2026-01-09 14:30:23 +08:00
|
|
|
|
ModelRequest,
|
|
|
|
|
|
ModelResponse,
|
|
|
|
|
|
TextPart,
|
|
|
|
|
|
TextPartDelta,
|
|
|
|
|
|
UserPromptPart,
|
2026-01-01 11:19:17 +08:00
|
|
|
|
)
|
2026-01-09 14:30:23 +08:00
|
|
|
|
|
|
|
|
|
|
from src.agent_app import AgentDeps, counselor_agent, study_advisor
|
|
|
|
|
|
from src.features import StudentFeatures
|
2026-01-01 11:19:17 +08:00
|
|
|
|
|
|
|
|
|
|
# Load environment variables (e.g. DEEPSEEK_API_KEY) from a local .env file.
load_dotenv()

# Page configuration is issued before any other Streamlit element.
st.set_page_config(page_title="学生成绩预测 AI 助手", page_icon="🎓", layout="wide")

# --- Sidebar configuration ---
sidebar = st.sidebar
sidebar.header("🔧 配置")

# Pre-fill the masked key field from the environment so a .env-provided key
# works without re-typing.
api_key = sidebar.text_input(
    "DeepSeek API Key",
    type="password",
    value=os.getenv("DEEPSEEK_API_KEY", ""),
)

# Publish a user-entered key through the environment, which is where the
# helper functions' key check looks for it.
# NOTE(review): os.environ is process-wide; in a multi-user Streamlit
# deployment a key typed by one session becomes the default (and the key in
# use) for every other session. Confirm this is acceptable.
if api_key:
    os.environ["DEEPSEEK_API_KEY"] = api_key

sidebar.markdown("---")

# Mode selection; the labels are compared verbatim by the main views.
mode = sidebar.radio("功能选择", ["📊 成绩预测", "💬 心理咨询"])
|
|
|
|
|
|
|
|
|
|
|
|
# --- Helper Functions ---
|
|
|
|
|
|
|
2026-01-09 14:30:23 +08:00
|
|
|
|
|
|
|
|
|
|
async def run_analysis(
    study_hours: float,
    sleep_hours: float,
    attendance_rate: float,
    stress_level: int,
    study_type: str,
):
    """Run the pass-probability analysis for one student via ``study_advisor``.

    The form values are packed into a ``StudentFeatures`` record and handed to
    the agent through ``AgentDeps``; the prompt text itself carries no student
    data.

    Args:
        study_hours: Weekly study time in hours.
        sleep_hours: Average nightly sleep in hours.
        attendance_rate: Attendance as a fraction (the form slider spans 0.0-1.0).
        stress_level: Stress rating (the form offers 1-5).
        study_type: Main study mode ("Self", "Group" or "Online" in the form).

    Returns:
        ``result.output`` of the agent run, or ``None`` when the DeepSeek API
        key is missing or anything fails. Failures are reported with
        ``st.error`` rather than raised.
    """
    try:
        # Fail fast with a UI hint when no key has been configured.
        if not os.getenv("DEEPSEEK_API_KEY"):
            st.error("请在侧边栏提供 DeepSeek API Key。")
            return None

        student = StudentFeatures(
            study_hours=study_hours,
            sleep_hours=sleep_hours,
            attendance_rate=attendance_rate,
            stress_level=stress_level,
            study_type=study_type,
        )
        deps = AgentDeps(student=student)

        # Plain string literals (they had no placeholders, so no f-prefix).
        # The student data travels in ``deps``; the prompt names the tools
        # the agent is asked to call.
        query = (
            "请分析这位学生的通过率并给出建议。"
            "学生信息已通过工具获取,请调用 predict_pass_probability 和 get_model_explanation 工具。"
        )

        with st.spinner("🤖 Agent 正在思考... (调用 DeepSeek + 随机森林模型)"):
            result = await study_advisor.run(query, deps=deps)
        return result.output

    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing the page.
        st.error(f"分析失败: {e!s}")
        return None
|
|
|
|
|
|
|
2026-01-09 14:30:23 +08:00
|
|
|
|
|
|
|
|
|
|
async def run_counselor_stream(
    query: str,
    history: list,
    placeholder,
    student: StudentFeatures,
):
    """Stream the counselor agent's reply into ``placeholder``.

    Consumes ``counselor_agent.run_stream_events`` manually. Text is rendered
    incrementally with a trailing "▌" cursor, and a transient status line is
    shown while the agent is running a tool.

    Args:
        query: The user's newest message.
        history: Earlier turns as pydantic-ai messages (``ModelRequest`` /
            ``ModelResponse``), passed as ``message_history``.
        placeholder: Streamlit container (from ``st.empty()``) that receives
            the streamed reply, or an error message.
        student: Student context handed to the agent via ``AgentDeps``.

    Returns:
        The complete reply text, or ``None`` when the API key is missing or
        the run fails. In those cases the error is rendered in ``placeholder``
        instead of being raised.
    """
    # pydantic-ai delivers the FIRST chunk of each text part inside a
    # PartStartEvent (as TextPart.content); only subsequent chunks arrive as
    # PartDeltaEvent. Ignoring the start event drops the reply's opening text.
    from pydantic_ai.messages import PartStartEvent

    status_placeholder = None  # created below; cleared in `finally`
    try:
        if not os.getenv("DEEPSEEK_API_KEY"):
            placeholder.error("❌ 错误: 请在侧边栏提供 DeepSeek API Key。")
            return None

        deps = AgentDeps(student=student)
        full_response = ""
        # Separate slot for tool-call status so it never mixes with the reply.
        status_placeholder = st.empty()

        async for event in counselor_agent.run_stream_events(
            query, deps=deps, message_history=history
        ):
            if isinstance(event, PartStartEvent) and isinstance(event.part, TextPart):
                # A new text part begins; its initial content is the first chunk.
                full_response += event.part.content
                placeholder.markdown(full_response + "▌")
            elif isinstance(event, PartDeltaEvent) and isinstance(event.delta, TextPartDelta):
                full_response += event.delta.content_delta
                placeholder.markdown(full_response + "▌")
            elif isinstance(event, FunctionToolCallEvent):
                status_placeholder.info(f"🛠️ 咨询师正在使用工具: `{event.part.tool_name}` ...")
            elif isinstance(event, FunctionToolResultEvent):
                status_placeholder.empty()

        # Final render without the typing cursor.
        placeholder.markdown(full_response)
        return full_response

    except Exception as e:
        placeholder.error(f"❌ 咨询失败: {e!s}")
        return None

    finally:
        # Remove any tool-status banner on every exit path, including errors.
        if status_placeholder is not None:
            status_placeholder.empty()
|
|
|
|
|
|
|
2026-01-09 14:30:23 +08:00
|
|
|
|
|
2026-01-01 11:19:17 +08:00
|
|
|
|
# --- Main Views ---
#
# NOTE(review): both views drive the async helpers by creating a fresh event
# loop per request (asyncio.new_event_loop + run_until_complete) and never
# close it. Switching to asyncio.run() would close each loop, which may clash
# with pydantic-ai's shared async HTTP client across reruns. Verify before
# changing.

if mode == "📊 成绩预测":
    st.title("🎓 学生成绩预测助手")
    st.markdown("在下方输入学生详细信息,获取 AI 驱动的成绩分析。")

    with st.form("student_data_form"):
        left, right = st.columns(2)

        with left:
            study_hours = st.slider("每周学习时长 (小时)", 0.0, 20.0, 10.0, 0.5)
            sleep_hours = st.slider("日均睡眠时长 (小时)", 0.0, 12.0, 7.0, 0.5)

        with right:
            attendance_rate = st.slider("出勤率", 0.0, 1.0, 0.9, 0.05)
            stress_level = st.select_slider(
                "压力等级 (1=低, 5=高)", options=[1, 2, 3, 4, 5], value=3
            )

        study_type = st.radio("主要学习方式", ["Self", "Group", "Online"], horizontal=True)

        submitted = st.form_submit_button("🚀 分析通过率")

    if submitted:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        guidance = loop.run_until_complete(
            run_analysis(
                study_hours=study_hours,
                sleep_hours=sleep_hours,
                attendance_rate=attendance_rate,
                stress_level=stress_level,
                study_type=study_type,
            )
        )

        # run_analysis yields None on failure (it has already shown the error).
        if guidance:
            st.divider()
            st.subheader("📊 分析结果")

            prob_col, risk_col, _spare_col = st.columns(3)
            probability = guidance.pass_probability
            prob_col.metric("预测通过率", f"{probability:.1%}")

            # A predicted pass probability below 60% is flagged as high risk.
            high_risk = probability < 0.6
            risk_col.metric(
                "风险评估",
                "高风险" if high_risk else "低风险",
                delta="-高风险" if high_risk else "+安全",
            )

            st.info(f"**风险评估:** {guidance.risk_assessment}")

            st.subheader("🔍 关键因素")
            for item in guidance.key_factors:
                st.write(f"- {item}")

            st.subheader("✅ 行动计划")
            st.table(
                [
                    {"优先级": step.priority, "建议行动": step.action}
                    for step in guidance.action_plan
                ]
            )

            st.subheader("💡 分析依据")
            st.write(guidance.rationale)

elif mode == "💬 心理咨询":
    st.title("🧩 AI 心理咨询室")
    st.markdown("这里是安全且私密的空间。有些压力如果你愿意说,我愿意听。")

    # Optional student profile in the sidebar, passed to the counselor as context.
    with st.sidebar.expander("📝 学生信息 (可选)", expanded=False):
        c_study_hours = st.slider("每周学习时长", 0.0, 20.0, 10.0, 0.5, key="c_study")
        c_sleep_hours = st.slider("日均睡眠时长", 0.0, 12.0, 7.0, 0.5, key="c_sleep")
        c_attendance = st.slider("出勤率", 0.0, 1.0, 0.9, 0.05, key="c_att")
        c_stress = st.select_slider("压力等级", options=[1, 2, 3, 4, 5], value=3, key="c_stress")
        c_study_type = st.radio("学习方式", ["Self", "Group", "Online"], key="c_type")

    # Chat transcript persisted across reruns as {"role", "content"} dicts.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the transcript on every rerun.
    for entry in st.session_state.messages:
        with st.chat_message(entry["role"]):
            st.markdown(entry["content"])

    if prompt := st.chat_input("想聊聊什么?"):
        with st.chat_message("user"):
            st.markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Convert earlier turns (everything except the prompt just appended)
        # into pydantic-ai messages; entries with any other role are skipped.
        to_model_message = {
            "user": lambda text: ModelRequest(parts=[UserPromptPart(content=text)]),
            "assistant": lambda text: ModelResponse(parts=[TextPart(content=text)]),
        }
        api_history = [
            to_model_message[turn["role"]](turn["content"])
            for turn in st.session_state.messages[:-1]
            if turn["role"] in to_model_message
        ]

        # Student context for the counselor, built from the sidebar widgets.
        student = StudentFeatures(
            study_hours=c_study_hours,
            sleep_hours=c_sleep_hours,
            attendance_rate=c_attendance,
            stress_level=c_stress,
            study_type=c_study_type,
        )

        with st.chat_message("assistant"):
            placeholder = st.empty()
            with st.spinner("咨询师正在倾听..."):
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                # The helper streams into `placeholder` itself and returns the final text.
                response_text = loop.run_until_complete(
                    run_counselor_stream(prompt, api_history, placeholder, student)
                )

        # Persist only non-empty replies; the helper returns None on failure.
        if response_text:
            st.session_state.messages.append({"role": "assistant", "content": response_text})
|