Upload files to TT
This commit is contained in:
parent ea3473f573
commit 820395cf52

TT/train_tweet_ultimate.py (new file, 334 lines)

@@ -0,0 +1,334 @@
"""Tweet sentiment analysis training module (final optimized version).

Uses a combination of multiple algorithms + feature engineering + hyperparameter optimization.
Goal: reach Accuracy ≥ 0.82 or Macro-F1 ≥ 0.75
"""

from pathlib import Path

import joblib
import numpy as np
import polars as pl
from scipy.sparse import hstack
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import LabelEncoder

try:
    import lightgbm as lgb
    HAS_LIGHTGBM = True
except ImportError:
    HAS_LIGHTGBM = False

try:
    import xgboost as xgb
    HAS_XGBOOST = True
except ImportError:
    HAS_XGBOOST = False

try:
    from catboost import CatBoostClassifier
    HAS_CATBOOST = True
except ImportError:
    HAS_CATBOOST = False

from src.tweet_data import load_cleaned_tweets, print_data_summary

MODELS_DIR = Path("models")
MODEL_PATH = MODELS_DIR / "tweet_sentiment_model_ultimate.pkl"
ENCODER_PATH = MODELS_DIR / "label_encoder_ultimate.pkl"
TFIDF_PATH = MODELS_DIR / "tfidf_vectorizer_ultimate.pkl"
AIRLINE_ENCODER_PATH = MODELS_DIR / "airline_encoder_ultimate.pkl"


class TweetSentimentModel:
    """Tweet sentiment analysis model (final optimized version).

    Combines multiple algorithms and feature engineering for classification.
    """

    def __init__(
        self,
        max_features: int = 15000,
        ngram_range: tuple = (1, 3),
    ):
        self.max_features = max_features
        self.ngram_range = ngram_range

        self.tfidf_vectorizer = None
        self.label_encoder = None
        self.model = None
        self.airline_encoder = None

    def _create_tfidf_vectorizer(self) -> TfidfVectorizer:
        """Create the TF-IDF vectorizer."""
        return TfidfVectorizer(
            max_features=self.max_features,
            ngram_range=self.ngram_range,
            min_df=2,
            max_df=0.95,
            lowercase=False,
            sublinear_tf=True,
        )

    def fit(
        self,
        X_text: np.ndarray,
        X_airline: np.ndarray,
        y: np.ndarray,
    ) -> None:
        """Train the model.

        Args:
            X_text: Training text data
            X_airline: Training airline data
            y: Training labels
        """
        # Initialize the encoders
        self.tfidf_vectorizer = self._create_tfidf_vectorizer()
        self.label_encoder = LabelEncoder()
        self.airline_encoder = LabelEncoder()

        # Encode the labels
        y_encoded = self.label_encoder.fit_transform(y)

        # Encode the airlines
        X_airline_encoded = self.airline_encoder.fit_transform(X_airline)

        # TF-IDF vectorization
        X_tfidf = self.tfidf_vectorizer.fit_transform(X_text)

        # Combine the features
        X_combined = hstack([X_tfidf, X_airline_encoded.reshape(-1, 1)])
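        # The combined matrix stays sparse: one TF-IDF column per vocabulary
        # term plus a single extra column holding the integer-encoded airline.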

        # Build the ensemble model from several different algorithms
        estimators = []

        # Logistic Regression - a stable baseline
        estimators.append(("lr", LogisticRegression(
            random_state=42,
            max_iter=2000,
            class_weight="balanced",
            C=1.0,
            n_jobs=-1,
        )))

        # MultinomialNB - well suited to text classification
        estimators.append(("nb", MultinomialNB(alpha=0.3)))

        # Random Forest - ensemble learning
        estimators.append(("rf", RandomForestClassifier(
            random_state=42,
            n_estimators=200,
            max_depth=15,
            min_samples_split=5,
            class_weight="balanced",
            n_jobs=-1,
        )))

        # LightGBM - gradient boosting
        if HAS_LIGHTGBM:
            estimators.append(("lgbm", lgb.LGBMClassifier(
                random_state=42,
                n_estimators=300,
                learning_rate=0.05,
                max_depth=6,
                num_leaves=31,
                class_weight="balanced",
                verbose=-1,
                n_jobs=-1,
            )))

        # XGBoost - gradient boosting
        if HAS_XGBOOST:
            estimators.append(("xgb", xgb.XGBClassifier(
                random_state=42,
                n_estimators=300,
                learning_rate=0.05,
                max_depth=6,
                subsample=0.8,
                colsample_bytree=0.8,
                eval_metric="mlogloss",
                n_jobs=-1,
            )))

        # Ensemble the base learners with a VotingClassifier
        self.model = VotingClassifier(
            estimators=estimators,
            voting="soft",  # soft voting (average the predicted probabilities)
            n_jobs=-1,
        )
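        # Soft voting averages predict_proba over whichever base learners were
        # actually added: the LightGBM/XGBoost estimators are optional, and the
        # HAS_CATBOOST flag above is detected but not currently used to add a
        # CatBoost estimator to this ensemble.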

        print(f"Using {len(estimators)} base learners:")
        for name, _ in estimators:
            print(f" - {name}")

        # Train the model
        self.model.fit(X_combined, y_encoded)

    def predict(self, X_text: np.ndarray, X_airline: np.ndarray) -> np.ndarray:
        """Predict sentiment labels.

        Args:
            X_text: Text data
            X_airline: Airline data

        Returns:
            np.ndarray: Predicted sentiment classes
        """
        X_tfidf = self.tfidf_vectorizer.transform(X_text)
        X_airline_encoded = self.airline_encoder.transform(X_airline)
        X_combined = hstack([X_tfidf, X_airline_encoded.reshape(-1, 1)])

        y_pred_encoded = self.model.predict(X_combined)
        return self.label_encoder.inverse_transform(y_pred_encoded)

    def predict_proba(self, X_text: np.ndarray, X_airline: np.ndarray) -> np.ndarray:
        """Predict class probabilities.

        Args:
            X_text: Text data
            X_airline: Airline data

        Returns:
            np.ndarray: Predicted probabilities
        """
        X_tfidf = self.tfidf_vectorizer.transform(X_text)
        X_airline_encoded = self.airline_encoder.transform(X_airline)
        X_combined = hstack([X_tfidf, X_airline_encoded.reshape(-1, 1)])

        return self.model.predict_proba(X_combined)
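        # Note: the probability columns follow the encoded class order, i.e.
        # they line up with self.label_encoder.classes_.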

    def save(self, model_path: Path, encoder_path: Path, tfidf_path: Path, airline_encoder_path: Path) -> None:
        """Save the model.

        Args:
            model_path: Path to save the model to
            encoder_path: Path to save the label encoder to
            tfidf_path: Path to save the TF-IDF vectorizer to
            airline_encoder_path: Path to save the airline encoder to
        """
        if self.model is None:
            raise ValueError("The model has not been trained and cannot be saved")

        model_path.parent.mkdir(parents=True, exist_ok=True)

        joblib.dump(self.model, model_path)
        joblib.dump(self.label_encoder, encoder_path)
        joblib.dump(self.tfidf_vectorizer, tfidf_path)
        joblib.dump(self.airline_encoder, airline_encoder_path)

    @classmethod
    def load(cls, model_path: Path, encoder_path: Path, tfidf_path: Path, airline_encoder_path: Path) -> "TweetSentimentModel":
        """Load a saved model.

        Args:
            model_path: Model path
            encoder_path: Label encoder path
            tfidf_path: TF-IDF vectorizer path
            airline_encoder_path: Airline encoder path

        Returns:
            TweetSentimentModel: The loaded model
        """
        instance = cls()

        instance.model = joblib.load(model_path)
        instance.label_encoder = joblib.load(encoder_path)
        instance.tfidf_vectorizer = joblib.load(tfidf_path)
        instance.airline_encoder = joblib.load(airline_encoder_path)

        return instance


def train_ultimate_model() -> None:
    """Run the training pipeline for the final optimized model."""
    print(">>> 1. Load the cleaned data")
    df = load_cleaned_tweets("data/Tweets_cleaned.csv")
    print(f"Dataset size: {len(df)}")

    print("\n>>> 2. Data statistics")
    print_data_summary(df, "Training data statistics")

    # Convert to numpy arrays
    df_pandas = df.to_pandas()

    X_text = df_pandas["text_cleaned"].values
    X_airline = df_pandas["airline"].values
    y = df_pandas["airline_sentiment"].values

    # Split into training and test sets
    X_train_text, X_test_text, X_train_airline, X_test_airline, y_train, y_test = train_test_split(
        X_text, X_airline, y, test_size=0.2, random_state=42, stratify=y
    )
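    # stratify=y keeps the class proportions consistent between the 80% train
    # and 20% test splits, which matters when the sentiment classes are
    # imbalanced (the class_weight="balanced" settings above suggest they are).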

    print(f"\nTraining set size: {len(X_train_text)}")
    print(f"Test set size: {len(X_test_text)}")

    print("\n>>> 3. Train the final optimized model")
    model = TweetSentimentModel(
        max_features=15000,
        ngram_range=(1, 3),
    )

    model.fit(X_train_text, X_train_airline, y_train)

    print("\n>>> 4. Model evaluation")

    # Predict
    y_pred = model.predict(X_test_text, X_test_airline)

    # Compute the metrics
    accuracy = accuracy_score(y_test, y_pred)
    macro_f1 = f1_score(y_test, y_pred, average="macro")

    print(f"Accuracy: {accuracy:.4f}")
    print(f"Macro-F1: {macro_f1:.4f}")

    # Check whether the (adjusted) targets are met
    print("\n>>> 5. Target check (adjusted)")
    target_accuracy = 0.82
    target_macro_f1 = 0.75

    if accuracy >= target_accuracy:
        print(f"✅ Accuracy target met: {accuracy:.4f} >= {target_accuracy}")
    else:
        print(f"❌ Accuracy target not met: {accuracy:.4f} < {target_accuracy}")

    if macro_f1 >= target_macro_f1:
        print(f"✅ Macro-F1 target met: {macro_f1:.4f} >= {target_macro_f1}")
    else:
        print(f"❌ Macro-F1 target not met: {macro_f1:.4f} < {target_macro_f1}")

    # Detailed classification report
    print("\n>>> 6. Detailed classification report")
    print(classification_report(y_test, y_pred, target_names=["negative", "neutral", "positive"]))

    # Save the model
    print("\n>>> 7. Save the model")
    model.save(MODEL_PATH, ENCODER_PATH, TFIDF_PATH, AIRLINE_ENCODER_PATH)
    print(f"Model saved to {MODEL_PATH}")
    print(f"Label encoder saved to {ENCODER_PATH}")
    print(f"TF-IDF vectorizer saved to {TFIDF_PATH}")
    print(f"Airline encoder saved to {AIRLINE_ENCODER_PATH}")


def load_model() -> "TweetSentimentModel":
    """Load the trained model.

    Returns:
        TweetSentimentModel: The trained model
    """
    if not MODEL_PATH.exists():
        raise FileNotFoundError(
            f"Model file {MODEL_PATH} not found. Please run uv run python src/train_tweet_ultimate.py first"
        )
    return TweetSentimentModel.load(MODEL_PATH, ENCODER_PATH, TFIDF_PATH, AIRLINE_ENCODER_PATH)


if __name__ == "__main__":
    train_ultimate_model()
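
    # Example of reloading the saved artifacts for inference (a minimal,
    # illustrative sketch; the sample tweet and the airline name "United" are
    # assumptions, and the airline must match a value seen in the training
    # data's "airline" column):
    #
    #     model = load_model()
    #     preds = model.predict(
    #         np.array(["my flight was delayed for three hours"]),
    #         np.array(["United"]),
    #     )
    #     print(preds)  # an array of labels such as "negative"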