import polars as pl
import pandera.pandas as pa
from pandera.pandas import Column, DataFrameSchema, Check
import numpy as np
from typing import Optional

# Default location of the Telco customer-churn CSV. Used by the __main__ smoke
# test and by preprocess_single_customer when it must rebuild the training
# feature layout.
DEFAULT_DATA_PATH = "data/Telco-Customer-Churn.csv"

# Allowed value sets shared by several schema columns.
_YES_NO = ["Yes", "No"]
_INTERNET_ADDON = ["Yes", "No", "No internet service"]

# Pandera (pandas-backend) schema for the cleaned Telco data. Every column is
# required and non-nullable; categorical columns are restricted to their
# known value sets and numeric columns must be >= 0 where checked.
telco_schema = DataFrameSchema({
    "customerID": Column(str, nullable=False),
    "gender": Column(str, Check.isin(["Male", "Female"]), nullable=False),
    "SeniorCitizen": Column(int, Check.isin([0, 1]), nullable=False),
    "Partner": Column(str, Check.isin(_YES_NO), nullable=False),
    "Dependents": Column(str, Check.isin(_YES_NO), nullable=False),
    "tenure": Column(int, Check.ge(0), nullable=False),
    "PhoneService": Column(str, Check.isin(_YES_NO), nullable=False),
    "MultipleLines": Column(str, Check.isin(["Yes", "No", "No phone service"]), nullable=False),
    "InternetService": Column(str, Check.isin(["DSL", "Fiber optic", "No"]), nullable=False),
    "OnlineSecurity": Column(str, Check.isin(_INTERNET_ADDON), nullable=False),
    "OnlineBackup": Column(str, Check.isin(_INTERNET_ADDON), nullable=False),
    "DeviceProtection": Column(str, Check.isin(_INTERNET_ADDON), nullable=False),
    "TechSupport": Column(str, Check.isin(_INTERNET_ADDON), nullable=False),
    "StreamingTV": Column(str, Check.isin(_INTERNET_ADDON), nullable=False),
    "StreamingMovies": Column(str, Check.isin(_INTERNET_ADDON), nullable=False),
    "Contract": Column(str, Check.isin(["Month-to-month", "One year", "Two year"]), nullable=False),
    "PaperlessBilling": Column(str, Check.isin(_YES_NO), nullable=False),
    "PaymentMethod": Column(str, nullable=False),
    "MonthlyCharges": Column(float, Check.ge(0), nullable=False),
    "TotalCharges": Column(float, Check.ge(0), nullable=False),
    "Churn": Column(str, Check.isin(_YES_NO), nullable=False),
})


def data_processing_pipeline(
    file_path: str,
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """Load, clean, validate and split the Telco churn dataset.

    Steps:
      1. Read the CSV, overriding ``TotalCharges`` to a string column so that
         blank entries can be cleaned before the numeric cast.
      2. Strip whitespace from ``TotalCharges``, turn empty strings into
         nulls, cast to Float64 and fill the remaining nulls with 0.0.
      3. Validate the frame against ``telco_schema``. The schema is a pandera
         *pandas* schema, so validation goes through a pandas round trip.
      4. Encode ``Churn`` as 1 ("Yes") / 0 ("No") with dtype Int64.
      5. Split into features and target.

    Args:
        file_path: Path to the Telco churn CSV file.

    Returns:
        ``(X, y, df)``: ``X`` is ``df`` without ``customerID`` and ``Churn``;
        ``y`` is a one-column frame holding ``Churn``; ``df`` is the full
        cleaned and validated frame.

    Raises:
        pandera.errors.SchemaError: if the data violates ``telco_schema``.
    """
    # 1. Read the data; TotalCharges is kept as text so blanks do not break
    #    the float parse.
    df = pl.read_csv(file_path, schema_overrides={"TotalCharges": pl.Utf8})

    # 2. Clean TotalCharges: blank/whitespace-only values become null and are
    #    then filled with 0.0 (never the mean). The original author's
    #    rationale: blanks occur for tenure == 0 customers, whose total
    #    charges can legitimately be 0.
    df = df.with_columns(
        pl.col("TotalCharges")
        .str.strip_chars()
        .replace("", None)
        .cast(pl.Float64)
        .fill_null(0.0)
    )

    # 3. Validate against the pandera schema (pandas backend), then return to
    #    polars for the rest of the pipeline.
    validated_df_pandas = telco_schema.validate(df.to_pandas())
    df = pl.from_pandas(validated_df_pandas)

    # 4. Encode the target. After validation Churn is guaranteed to be
    #    "Yes"/"No", so a boolean comparison is an exact 1/0 mapping and does
    #    not depend on how a particular polars version types `replace`.
    df = df.with_columns((pl.col("Churn") == "Yes").cast(pl.Int64))

    # 5. Separate features from the target.
    X = df.drop(["customerID", "Churn"])
    y = df.select("Churn")
    return X, y, df


# Training-time feature layout, set by preprocess_data and reused by
# preprocess_single_customer so single-row inputs are encoded into exactly
# the same columns, in the same order, as the training matrix.
_encoded_columns: Optional[list[str]] = None  # one-hot-encoded column names, in order
_feature_columns: Optional[list[str]] = None  # raw feature names before encoding


def preprocess_data(X: pl.DataFrame, y: pl.DataFrame) -> tuple[np.ndarray, np.ndarray]:
    """One-hot encode the features and convert features/target to numpy.

    Side effect: records the raw feature names and the encoded column layout
    in the module globals ``_feature_columns`` and ``_encoded_columns`` for
    later use by ``preprocess_single_customer``.

    Args:
        X: Feature frame; every string (Utf8) column is one-hot encoded, all
            other columns are passed through unchanged.
        y: Single-column target frame.

    Returns:
        ``(X_np, y_np)``: the encoded feature matrix and the flattened target.
    """
    global _encoded_columns, _feature_columns

    categorical_cols = X.select(pl.col(pl.Utf8)).columns
    X_encoded = X.to_dummies(columns=categorical_cols)

    _feature_columns = X.columns
    _encoded_columns = X_encoded.columns

    return X_encoded.to_numpy(), y.to_numpy().ravel()


def preprocess_single_customer(
    customer_data: pl.DataFrame, *, train_path: str = DEFAULT_DATA_PATH
) -> np.ndarray:
    """Encode customer row(s) into the training feature layout.

    If no training layout has been recorded yet, the training data is loaded
    from ``train_path`` and run through ``data_processing_pipeline`` and
    ``preprocess_data`` to build it (this reads the CSV from disk).

    Dummy columns for categories absent from ``customer_data`` are added as
    zeros. Columns that are not part of the training layout (e.g.
    ``customerID``, or dummies for categories unseen during training) are
    dropped by the final column selection.

    Args:
        customer_data: Frame with the raw feature columns used in training.
        train_path: CSV used to rebuild the layout when it is not cached.
            Keyword-only; defaults to ``DEFAULT_DATA_PATH``.

    Returns:
        numpy array whose columns match the training matrix order.

    Raises:
        ValueError: if ``customer_data`` lacks a raw feature column seen in
            training. Previously such a column would silently be encoded as 0.
    """
    if _encoded_columns is None:
        X_train, y_train, _ = data_processing_pipeline(train_path)
        preprocess_data(X_train, y_train)

    # Only possible to check when the layout was built by preprocess_data in
    # this process (the globals may have been set externally).
    if _feature_columns is not None:
        present = set(customer_data.columns)
        missing_features = [c for c in _feature_columns if c not in present]
        if missing_features:
            raise ValueError(
                f"customer_data is missing feature column(s): {missing_features}"
            )

    categorical_cols = customer_data.select(pl.col(pl.Utf8)).columns
    X_encoded = customer_data.to_dummies(columns=categorical_cols)

    # Add zero-valued dummies for training categories this input lacks. They
    # use UInt8 to match the dtype polars' to_dummies produces, and are added
    # in one with_columns call.
    present_encoded = set(X_encoded.columns)
    missing_dummies = [c for c in _encoded_columns if c not in present_encoded]
    if missing_dummies:
        X_encoded = X_encoded.with_columns(
            [pl.lit(0, dtype=pl.UInt8).alias(c) for c in missing_dummies]
        )

    # Reorder (and restrict) to the exact training column order.
    return X_encoded.select(_encoded_columns).to_numpy()


if __name__ == "__main__":
    # Smoke test: run the pipeline on the default dataset and report shapes.
    X, y, df = data_processing_pipeline(DEFAULT_DATA_PATH)
    print("数据处理完成!")
    print(f"特征数据形状: {X.shape}")
    print(f"目标变量形状: {y.shape}")
    print(f"清洗后的数据行数: {df.shape[0]}")