akaAKR47/data_processing.py

129 lines
5.0 KiB
Python
Raw Normal View History

import polars as pl
import pandera.pandas as pa
from pandera.pandas import Column, DataFrameSchema, Check
import numpy as np
# Pandera schema for the raw customer table; every column is required to be
# non-null and categorical columns are restricted to their known labels.
telco_schema = DataFrameSchema(
    columns={
        # Customer identity and demographics.
        "customerID": Column(str, nullable=False),
        "gender": Column(str, checks=Check.isin(["Male", "Female"]), nullable=False),
        "SeniorCitizen": Column(int, checks=Check.isin([0, 1]), nullable=False),
        "Partner": Column(str, checks=Check.isin(["Yes", "No"]), nullable=False),
        "Dependents": Column(str, checks=Check.isin(["Yes", "No"]), nullable=False),
        "tenure": Column(int, checks=Check.ge(0), nullable=False),
        # Phone and internet services.
        "PhoneService": Column(str, checks=Check.isin(["Yes", "No"]), nullable=False),
        "MultipleLines": Column(
            str,
            checks=Check.isin(["Yes", "No", "No phone service"]),
            nullable=False,
        ),
        "InternetService": Column(
            str,
            checks=Check.isin(["DSL", "Fiber optic", "No"]),
            nullable=False,
        ),
        "OnlineSecurity": Column(
            str,
            checks=Check.isin(["Yes", "No", "No internet service"]),
            nullable=False,
        ),
        "OnlineBackup": Column(
            str,
            checks=Check.isin(["Yes", "No", "No internet service"]),
            nullable=False,
        ),
        "DeviceProtection": Column(
            str,
            checks=Check.isin(["Yes", "No", "No internet service"]),
            nullable=False,
        ),
        "TechSupport": Column(
            str,
            checks=Check.isin(["Yes", "No", "No internet service"]),
            nullable=False,
        ),
        "StreamingTV": Column(
            str,
            checks=Check.isin(["Yes", "No", "No internet service"]),
            nullable=False,
        ),
        "StreamingMovies": Column(
            str,
            checks=Check.isin(["Yes", "No", "No internet service"]),
            nullable=False,
        ),
        # Contract and billing.
        "Contract": Column(
            str,
            checks=Check.isin(["Month-to-month", "One year", "Two year"]),
            nullable=False,
        ),
        "PaperlessBilling": Column(str, checks=Check.isin(["Yes", "No"]), nullable=False),
        "PaymentMethod": Column(str, nullable=False),
        "MonthlyCharges": Column(float, checks=Check.ge(0), nullable=False),
        "TotalCharges": Column(float, checks=Check.ge(0), nullable=False),
        # Target label.
        "Churn": Column(str, checks=Check.isin(["Yes", "No"]), nullable=False),
    }
)
# Data processing pipeline
def data_processing_pipeline(
    file_path: str,
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """Load the customer CSV, clean it, validate it and split off the target.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        A tuple ``(X, y, df)``:

        * ``df`` is the cleaned, schema-validated frame with ``Churn``
          encoded as 1 (``"Yes"``) or 0 (``"No"``).
        * ``X`` is ``df`` without the ``customerID`` and ``Churn`` columns.
        * ``y`` is a single-column frame containing ``Churn``.

    Raises:
        Propagates whatever ``pl.read_csv`` raises for an unreadable file and
        whatever ``telco_schema.validate`` raises on a schema violation.
    """
    # 1. Read the data. TotalCharges is read as text so that blank cells can
    #    be handled explicitly instead of breaking numeric parsing.
    df = pl.read_csv(file_path, schema_overrides={"TotalCharges": pl.Utf8})

    # 2. Clean TotalCharges in a single pass: trim whitespace, turn empty
    #    strings into nulls, parse as float, then fill the nulls with 0.0
    #    (a customer with tenure 0 may legitimately have no charges yet).
    df = df.with_columns(
        pl.col("TotalCharges")
        .str.strip_chars()
        .replace("", None)
        .cast(pl.Float64)
        .fill_null(0.0)
    )

    # 3. Validate against the Pandera schema. The schema is defined for
    #    pandas, so round-trip through a pandas DataFrame.
    validated_pandas = telco_schema.validate(df.to_pandas())
    df = pl.from_pandas(validated_pandas)

    # 4. Encode the target as 0/1. The schema has already guaranteed that
    #    Churn is non-null and is either "Yes" or "No", so a boolean
    #    comparison is sufficient and avoids using Expr.replace to change the
    #    column's dtype.
    df = df.with_columns(
        (pl.col("Churn") == "Yes").cast(pl.Int64).alias("Churn")
    )

    # 5. Separate the features from the target.
    X = df.drop(["customerID", "Churn"])
    y = df.select("Churn")
    return X, y, df
# Module-level cache of the one-hot encoded feature column names. It is
# assigned by preprocess_data and read by preprocess_single_customer; None
# means no training data has been encoded yet.
_encoded_columns = None
# Data preprocessing (for model training)
def preprocess_data(
    X: pl.DataFrame, y: pl.DataFrame
) -> tuple[np.ndarray, np.ndarray]:
    """One-hot encode the string features and convert both frames to NumPy.

    As a side effect, the encoded column names are stored in the module-level
    ``_encoded_columns`` so that later single-customer inputs can be aligned
    to the same layout.

    Args:
        X: Feature frame; its string (``pl.Utf8``) columns are one-hot encoded
            and all other columns are passed through unchanged.
        y: Target frame.

    Returns:
        A tuple ``(X_np, y_np)`` where ``X_np`` is the encoded feature matrix
        and ``y_np`` is the flattened target array.
    """
    global _encoded_columns

    # Only string columns are treated as categorical.
    categorical_cols = X.select(pl.col(pl.Utf8)).columns
    X_encoded = X.to_dummies(columns=categorical_cols)

    # Remember the encoded layout for preprocess_single_customer.
    _encoded_columns = X_encoded.columns

    return X_encoded.to_numpy(), y.to_numpy().ravel()
# Data preprocessing (for single-customer prediction)
def preprocess_single_customer(
    customer_data: pl.DataFrame,
    *,
    train_data_path: str = "data/Telco-Customer-Churn.csv",
) -> np.ndarray:
    """Encode customer rows with the same column layout used for training.

    If no training data has been encoded yet (``_encoded_columns`` is
    ``None``), the training CSV is loaded and run through ``preprocess_data``
    first so that the expected column layout is known.

    Args:
        customer_data: Frame holding the customer features to encode.
        train_data_path: CSV used to derive the encoded column layout when it
            has not been computed yet. Defaults to the path that was
            previously hard-coded.

    Returns:
        A NumPy array whose columns follow ``_encoded_columns`` exactly.
        Encoded columns absent from ``customer_data`` are filled with 0 and
        columns not seen during training are dropped.
    """
    if _encoded_columns is None:
        # The pipeline already returns the feature/target split, so reuse it
        # instead of recomputing it from the full frame.
        X_train, y_train, _ = data_processing_pipeline(train_data_path)
        preprocess_data(X_train, y_train)  # populates _encoded_columns

    # One-hot encode the string columns of the incoming rows.
    categorical_cols = customer_data.select(pl.col(pl.Utf8)).columns
    X_encoded = customer_data.to_dummies(columns=categorical_cols)

    # Add every training-time column that is missing here as a constant 0,
    # in one with_columns call rather than one call per column.
    present = set(X_encoded.columns)
    missing = [
        pl.lit(0).alias(name)
        for name in _encoded_columns
        if name not in present
    ]
    if missing:
        X_encoded = X_encoded.with_columns(missing)

    # Reorder (and restrict) to the training-time column order.
    return X_encoded.select(_encoded_columns).to_numpy()
if __name__ == "__main__":
    # Smoke-test the data processing pipeline and report the resulting sizes.
    X, y, df = data_processing_pipeline("data/Telco-Customer-Churn.csv")
    print(
        "数据处理完成!",
        f"特征数据形状: {X.shape}",
        f"目标变量形状: {y.shape}",
        f"清洗后的数据行数: {df.shape[0]}",
        sep="\n",
    )