第二部分:数据采集模块
2.1 数据采集架构设计
数据采集是数据中台的第一道关卡,需要从多个异构数据源高效、可靠地采集数据。
┌─────────────────────────────────────────────────────────────────┐
│ 数据采集架构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ 订单系统 │ │ 用户系统 │ │ 商品系统 │ │
│ │ (MySQL) │ │ (MySQL) │ │ (MongoDB) │ │
│ └──────┬───────┘ └──────┬───────┘ └──────┬───────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ Binlog同步 │ │ 定时拉取 │ │ 增量查询 │ │
│ │ (Canal) │ │ (APScheduler)│ │ (MongoDB) │ │
│ └──────┬───────┘ └──────┬───────┘ └──────┬───────┘ │
│ │ │ │ │
│ └───────────────────┼───────────────────┘ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ 消息队列(Kafka)│ │
│ └────────┬────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ 数据清洗/转换 │ │
│ │ (Spark/Pandas)│ │
│ └────────┬────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ 数据仓库/数据湖 │ │
│ │ (PostgreSQL/HDFS) │
│ └─────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
2.2 基类设计:抽象采集器
# src/collector/base.py
"""
数据采集器基类
所有具体的数据采集器都继承这个基类,实现统一的数据采集接口。
这样可以方便地添加新的数据源,并且统一处理错误、重试、日志等横切关注点。
"""
import abc
import time
from typing import Any, Dict, Iterator, List, Optional
from loguru import logger
from tenacity import retry, stop_after_attempt, wait_exponential
from src.utils.decorators import log_execution_time, monitor_metrics
class BaseCollector(abc.ABC):
"""
数据采集器基类
定义了数据采集的标准接口:
- collect(): 采集数据的主方法
- validate(): 验证采集到的数据
- transform(): 转换为统一格式
"""
def __init__(self, source_name: str, batch_size: int = 1000):
"""
初始化采集器
Args:
source_name: 数据源名称(用于日志和监控)
batch_size: 批量采集大小
"""
self.source_name = source_name
self.batch_size = batch_size
self.stats = {
"total_collected": 0,
"total_success": 0,
"total_failed": 0,
"last_collect_time": None,
}
@abc.abstractmethod
def connect(self) -> bool:
"""
建立与数据源的连接
Returns:
连接是否成功
"""
pass
@abc.abstractmethod
def fetch(self, last_timestamp: Optional[str] = None) -> Iterator[List[Dict]]:
"""
从数据源获取数据(分批返回)
Args:
last_timestamp: 上次采集的时间戳(用于增量采集)
Yields:
每批数据的列表
"""
pass
@abc.abstractmethod
def validate(self, data: Dict) -> bool:
"""
验证单条数据的有效性
Args:
data: 待验证的数据
Returns:
数据是否有效
"""
pass
@abc.abstractmethod
def transform(self, raw_data: Dict) -> Dict:
"""
将原始数据转换为统一格式
Args:
raw_data: 原始数据
Returns:
统一格式的数据
"""
pass
@log_execution_time
@monitor_metrics
@retry(
stop=stop_after_attempt(3), # 最多重试3次
wait=wait_exponential(multiplier=1, min=1, max=10), # 指数退避等待
)
def collect(self, incremental: bool = True) -> int:
"""
执行数据采集(带重试和监控)
Args:
incremental: 是否增量采集
Returns:
采集成功的记录数
"""
start_time = time.time()
logger.info(f"开始采集数据源: {self.source_name}, 增量模式: {incremental}")
# 获取上次采集的时间戳
last_ts = None
if incremental:
last_ts = self._get_last_timestamp()
success_count = 0
try:
for batch in self.fetch(last_ts):
for raw_data in batch:
try:
# 数据验证
if not self.validate(raw_data):
logger.warning(f"数据验证失败: {raw_data}")
self.stats["total_failed"] += 1
continue
# 数据转换
transformed = self.transform(raw_data)
# 写入数据仓库
self._write_to_warehouse(transformed)
success_count += 1
self.stats["total_success"] += 1
except Exception as e:
logger.error(f"处理数据失败: {e}, data={raw_data}")
self.stats["total_failed"] += 1
# 更新批次进度
self.stats["total_collected"] += len(batch)
logger.debug(f"已采集 {self.stats['total_collected']} 条记录")
# 更新最后采集时间
self._update_last_timestamp()
except Exception as e:
logger.error(f"采集失败: {e}")
raise
elapsed = time.time() - start_time
logger.info(
f"采集完成: 源={self.source_name}, "
f"成功={success_count}, "
f"失败={self.stats['total_failed']}, "
f"耗时={elapsed:.2f}s, "
f"速率={success_count/elapsed:.0f}条/秒"
)
self.stats["last_collect_time"] = time.time()
return success_count
def _get_last_timestamp(self) -> Optional[str]:
"""
获取上次采集的时间戳
从Redis或数据库读取,用于增量采集。
"""
# 这里简化实现,实际可以从Redis读取
return None
def _update_last_timestamp(self) -> None:
"""更新最后采集时间戳到Redis"""
# 这里简化实现,实际写入Redis
pass
def _write_to_warehouse(self, data: Dict) -> None:
"""
将数据写入数据仓库
子类可以覆盖此方法实现不同的写入逻辑。
"""
# 默认实现:打印数据
logger.debug(f"写入数据: {data}")
2.3 MySQL数据采集器(订单数据)
# src/collector/mysql_collector.py
"""
MySQL数据采集器
用于从MySQL数据库采集订单数据和用户数据。
支持全量采集和增量采集(基于更新时间戳)。
"""
import json
from typing import Any, Dict, Iterator, List, Optional
import pymysql
from pymysql.cursors import DictCursor
from src.collector.base import BaseCollector
from config.settings import settings
from src.utils.logger import logger
class MySQLCollector(BaseCollector):
"""
MySQL数据采集器
特点:
1. 支持断点续传(记录上次采集时间)
2. 分批查询,避免内存溢出
3. 连接池管理
"""
def __init__(self, table_name: str, db_name: str, incremental_column: str = "update_time"):
"""
初始化MySQL采集器
Args:
table_name: 要采集的表名
db_name: 数据库名(order_db/user_db)
incremental_column: 增量采集使用的列名(通常是update_time或create_time)
"""
super().__init__(source_name=f"mysql_{db_name}.{table_name}")
self.table_name = table_name
self.db_name = db_name
self.incremental_column = incremental_column
self._connection = None
def connect(self) -> bool:
"""建立MySQL连接"""
try:
# 根据数据库名选择不同的配置
if self.db_name == "order_db":
host = settings.MYSQL_HOST
port = settings.MYSQL_PORT
user = settings.MYSQL_USER
password = settings.MYSQL_PASSWORD
database = settings.MYSQL_ORDER_DB
else: # user_db
host = settings.MYSQL_HOST
port = settings.MYSQL_PORT
user = settings.MYSQL_USER
password = settings.MYSQL_PASSWORD
database = settings.MYSQL_USER_DB
self._connection = pymysql.connect(
host=host,
port=port,
user=user,
password=password,
database=database,
charset="utf8mb4",
cursorclass=DictCursor,
# 连接池配置
autocommit=False,
connect_timeout=10,
read_timeout=30,
write_timeout=30,
)
logger.info(f"MySQL连接成功: {self.db_name}.{self.table_name}")
return True
except Exception as e:
logger.error(f"MySQL连接失败: {e}")
return False
def fetch(self, last_timestamp: Optional[str] = None) -> Iterator[List[Dict]]:
"""
分批获取MySQL数据
使用OFFSET/LIMIT分批查询,避免一次加载过多数据。
对于增量采集,使用WHERE条件过滤。
"""
if not self._connection:
self.connect()
# 构建查询SQL
if last_timestamp:
# 增量采集:只查询更新的数据
sql = f"""
SELECT * FROM {self.table_name}
WHERE {self.incremental_column} >= %s
ORDER BY {self.incremental_column} ASC, id ASC
"""
params = [last_timestamp]
else:
# 全量采集:查询所有数据
sql = f"SELECT * FROM {self.table_name} ORDER BY id ASC"
params = []
offset = 0
while True:
# 分页查询
paginated_sql = sql + f" LIMIT {self.batch_size} OFFSET {offset}"
with self._connection.cursor() as cursor:
cursor.execute(paginated_sql, params)
rows = cursor.fetchall()
if not rows:
break
yield rows
offset += self.batch_size
logger.debug(f"已读取 {offset} 条记录")
def validate(self, data: Dict) -> bool:
"""验证MySQL数据"""
# 基本验证:必须有ID
if "id" not in data or data["id"] is None:
logger.warning("数据缺少id字段")
return False
# 根据表类型进行特定验证
if self.table_name == "orders":
# 订单必须有order_no和user_id
if not data.get("order_no") or not data.get("user_id"):
logger.warning(f"订单数据缺少必要字段: {data.get('id')}")
return False
elif self.table_name == "users":
# 用户必须有username
if not data.get("username"):
logger.warning(f"用户数据缺少username: {data.get('id')}")
return False
return True
def transform(self, raw_data: Dict) -> Dict:
"""
将MySQL数据转换为统一格式
转换内容包括:
1. 字段名统一(下划线转驼峰可选)
2. 时间格式统一
3. 空值处理
"""
transformed = {}
# 基础字段
transformed["source_id"] = str(raw_data["id"])
transformed["source_system"] = self.db_name
transformed["source_table"] = self.table_name
# 数据内容
data = {}
# 处理订单表
if self.table_name == "orders":
data = {
"order_id": raw_data["id"],
"order_no": raw_data.get("order_no", ""),
"user_id": raw_data.get("user_id"),
"total_amount": float(raw_data.get("total_amount", 0)),
"pay_amount": float(raw_data.get("pay_amount", 0)),
"discount_amount": float(raw_data.get("discount_amount", 0)),
"order_status": raw_data.get("order_status", 0),
"pay_status": raw_data.get("pay_status", 0),
"shipping_status": raw_data.get("shipping_status", 0),
"create_time": raw_data.get("create_time"),
"pay_time": raw_data.get("pay_time"),
"shipping_time": raw_data.get("shipping_time"),
"complete_time": raw_data.get("complete_time"),
"update_time": raw_data.get("update_time"),
}
# 处理用户表
elif self.table_name == "users":
data = {
"user_id": raw_data["id"],
"username": raw_data.get("username", ""),
"phone": raw_data.get("phone", ""),
"email": raw_data.get("email", ""),
"gender": raw_data.get("gender", 0),
"birthday": raw_data.get("birthday"),
"register_time": raw_data.get("create_time"),
"last_login_time": raw_data.get("last_login_time"),
"user_status": raw_data.get("status", 1),
"user_level": raw_data.get("level", 1),
}
transformed["data"] = data
transformed["etl_time"] = self._get_current_time()
transformed["etl_batch_id"] = self._get_batch_id()
return transformed
def _get_current_time(self) -> str:
"""获取当前时间字符串"""
from datetime import datetime
return datetime.now().isoformat()
def _get_batch_id(self) -> str:
"""获取批次ID(用于数据血缘追踪)"""
import uuid
return str(uuid.uuid4())
def close(self):
"""关闭数据库连接"""
if self._connection:
self._connection.close()
logger.info(f"MySQL连接已关闭: {self.source_name}")
2.4 MongoDB数据采集器(商品数据)
# src/collector/mongodb_collector.py
"""
MongoDB数据采集器
用于从MongoDB采集商品数据。
MongoDB是文档数据库,商品信息通常存储在嵌套文档中,需要特殊处理。
"""
from typing import Dict, Iterator, List, Optional
from pymongo import MongoClient
from pymongo.errors import PyMongoError
from src.collector.base import BaseCollector
from config.settings import settings
from src.utils.logger import logger
class MongoDBCollector(BaseCollector):
"""
MongoDB数据采集器
特点:
1. 支持嵌套文档的扁平化处理
2. 支持增量采集(基于ObjectId或时间戳)
3. 处理MongoDB特有的数据类型(ObjectId、ISODate等)
"""
def __init__(self, collection_name: str, incremental_field: str = "_id"):
"""
初始化MongoDB采集器
Args:
collection_name: 集合名称(如products)
incremental_field: 增量采集使用的字段
"""
super().__init__(source_name=f"mongodb.product_db.{collection_name}")
self.collection_name = collection_name
self.incremental_field = incremental_field
self._client = None
self._collection = None
def connect(self) -> bool:
"""建立MongoDB连接"""
try:
# 构建连接URI
if settings.MONGODB_USER and settings.MONGODB_PASSWORD:
uri = f"mongodb://{settings.MONGODB_USER}:{settings.MONGODB_PASSWORD}@" \
f"{settings.MONGODB_HOST}:{settings.MONGODB_PORT}"
else:
uri = f"mongodb://{settings.MONGODB_HOST}:{settings.MONGODB_PORT}"
self._client = MongoClient(uri, serverSelectionTimeoutMS=5000)
# 测试连接
self._client.admin.command('ping')
db = self._client[settings.MONGODB_DB]
self._collection = db[self.collection_name]
logger.info(f"MongoDB连接成功: {settings.MONGODB_DB}.{self.collection_name}")
return True
except PyMongoError as e:
logger.error(f"MongoDB连接失败: {e}")
return False
def fetch(self, last_timestamp: Optional[str] = None) -> Iterator[List[Dict]]:
"""分批获取MongoDB数据"""
if not self._collection:
self.connect()
# 构建查询条件
query = {}
if last_timestamp and self.incremental_field == "_id":
# 基于ObjectId的增量查询
from bson.objectid import ObjectId
query["_id"] = {"$gt": ObjectId(last_timestamp)}
elif last_timestamp:
query[self.incremental_field] = {"$gt": last_timestamp}
# 排序
sort_field = self.incremental_field if self.incremental_field == "_id" else self.incremental_field
cursor = self._collection.find(query).sort(sort_field, 1)
# 分批返回
batch = []
for doc in cursor:
batch.append(doc)
if len(batch) >= self.batch_size:
yield batch
batch = []
if batch:
yield batch
def validate(self, data: Dict) -> bool:
"""验证MongoDB数据"""
# 必须有_id
if "_id" not in data:
logger.warning("MongoDB数据缺少_id字段")
return False
# 商品必须有名称
if self.collection_name == "products" and not data.get("name"):
logger.warning(f"商品数据缺少name: {data.get('_id')}")
return False
return True
def _flatten_dict(self, data: Dict, parent_key: str = "", sep: str = "_") -> Dict:
"""
扁平化嵌套字典
MongoDB的文档经常包含嵌套结构,如:
{
"name": "iPhone 15",
"spec": {"color": "black", "storage": "256GB"},
"price": {"original": 9999, "current": 7999}
}
扁平化后变为:
{
"name": "iPhone 15",
"spec_color": "black",
"spec_storage": "256GB",
"price_original": 9999,
"price_current": 7999
}
"""
items = {}
for k, v in data.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.update(self._flatten_dict(v, new_key, sep=sep))
else:
items[new_key] = v
return items
def transform(self, raw_data: Dict) -> Dict:
"""
将MongoDB数据转换为统一格式
特殊处理:
1. ObjectId转换为字符串
2. ISODate转换为ISO格式字符串
3. 扁平化嵌套文档
"""
transformed = {}
# 处理ObjectId
from bson.objectid import ObjectId
if "_id" in raw_data and isinstance(raw_data["_id"], ObjectId):
transformed["source_id"] = str(raw_data["_id"])
raw_data["_id"] = str(raw_data["_id"]) # 替换为字符串
else:
transformed["source_id"] = str(raw_data.get("_id", ""))
transformed["source_system"] = "mongodb"
transformed["source_table"] = self.collection_name
# 扁平化文档
flattened = self._flatten_dict(raw_data)
# 处理日期类型
from datetime import datetime
for key, value in flattened.items():
if isinstance(value, datetime):
flattened[key] = value.isoformat()
transformed["data"] = flattened
transformed["etl_time"] = self._get_current_time()
transformed["etl_batch_id"] = self._get_batch_id()
return transformed
def _get_current_time(self) -> str:
from datetime import datetime
return datetime.now().isoformat()
def _get_batch_id(self) -> str:
import uuid
return str(uuid.uuid4())
def close(self):
"""关闭MongoDB连接"""
if self._client:
self._client.close()
logger.info(f"MongoDB连接已关闭: {self.source_name}")
第三部分:数据处理模块
3.1 数据清洗器
# src/processor/cleaner.py
"""
数据清洗模块
数据清洗是ETL过程中最关键的环节之一。
脏数据会导致分析结果错误,甚至影响业务决策。
常见的数据质量问题:
1. 缺失值:字段为空或NULL
2. 重复数据:同一实体有多条记录
3. 格式错误:日期格式不统一、金额符号混入等
4. 异常值:年龄1000岁、金额负值等
5. 不一致数据:同一用户在不同系统有不同手机号
"""
import re
from typing import Any, Dict, List, Optional
import pandas as pd
from loguru import logger
class DataCleaner:
"""
数据清洗器
提供常用的数据清洗方法:
- 去重
- 空值处理
- 格式标准化
- 异常值检测
"""
def __init__(self, config: Optional[Dict] = None):
"""
初始化清洗器
Args:
config: 清洗配置,包括:
- dedup_keys: 去重使用的字段列表
- fillna_strategy: 空值填充策略
- outlier_detection: 异常值检测规则
"""
self.config = config or {}
self.cleaning_stats = {
"rows_before": 0,
"rows_after": 0,
"duplicates_removed": 0,
"nulls_filled": 0,
"outliers_removed": 0,
}
def remove_duplicates(self, df: pd.DataFrame, subset: List[str] = None) -> pd.DataFrame:
"""
去除重复行
Args:
df: 输入DataFrame
subset: 用于判断重复的字段列表,默认使用所有字段
Returns:
去重后的DataFrame
"""
before = len(df)
if subset is None:
# 默认使用所有字段
df_deduplicated = df.drop_duplicates()
else:
df_deduplicated = df.drop_duplicates(subset=subset)
after = len(df_deduplicated)
removed = before - after
self.cleaning_stats["duplicates_removed"] += removed
self.cleaning_stats["rows_before"] += before
self.cleaning_stats["rows_after"] += after
if removed > 0:
logger.info(f"去除重复行: {removed} 条")
return df_deduplicated
def handle_missing_values(
self,
df: pd.DataFrame,
strategy: str = "auto",
fill_value: Any = None,
columns: List[str] = None
) -> pd.DataFrame:
"""
处理缺失值
策略说明:
- auto: 自动选择策略(数值列用中位数,类别列用众数)
- drop: 删除包含缺失值的行
- fill: 用指定值填充
- mean: 用均值填充(数值列)
- median: 用中位数填充(数值列)
- mode: 用众数填充(类别列)
- forward: 用前一个值填充(时间序列)
- backward: 用后一个值填充(时间序列)
Args:
df: 输入DataFrame
strategy: 填充策略
fill_value: 自定义填充值
columns: 要处理的列,默认所有列
Returns:
处理后的DataFrame
"""
if columns is None:
columns = df.columns.tolist()
for col in columns:
null_count = df[col].isna().sum()
if null_count == 0:
continue
logger.debug(f"列 {col} 有 {null_count} 个缺失值")
if strategy == "drop":
df = df.dropna(subset=[col])
self.cleaning_stats["nulls_filled"] += null_count
elif strategy == "fill" and fill_value is not None:
df[col] = df[col].fillna(fill_value)
self.cleaning_stats["nulls_filled"] += null_count
elif strategy == "mean" and pd.api.types.is_numeric_dtype(df[col]):
mean_val = df[col].mean()
df[col] = df[col].fillna(mean_val)
self.cleaning_stats["nulls_filled"] += null_count
elif strategy == "median" and pd.api.types.is_numeric_dtype(df[col]):
median_val = df[col].median()
df[col] = df[col].fillna(median_val)
self.cleaning_stats["nulls_filled"] += null_count
elif strategy == "mode":
mode_val = df[col].mode()
if len(mode_val) > 0:
df[col] = df[col].fillna(mode_val[0])
self.cleaning_stats["nulls_filled"] += null_count
elif strategy == "auto":
# 自动选择策略
if pd.api.types.is_numeric_dtype(df[col]):
# 数值列用中位数
median_val = df[col].median()
df[col] = df[col].fillna(median_val)
self.cleaning_stats["nulls_filled"] += null_count
else:
# 类别列用众数
mode_val = df[col].mode()
if len(mode_val) > 0:
df[col] = df[col].fillna(mode_val[0])
self.cleaning_stats["nulls_filled"] += null_count
else:
logger.warning(f"不支持的填充策略: {strategy}")
return df
def standardize_formats(self, df: pd.DataFrame) -> pd.DataFrame:
"""
标准化数据格式
处理内容:
1. 日期时间格式统一为 ISO 8601
2. 金额字段转为decimal
3. 手机号格式统一
4. 邮箱转小写
5. 去除字符串首尾空格
"""
# 日期时间列检测和转换
for col in df.columns:
if "time" in col.lower() or "date" in col.lower():
try:
df[col] = pd.to_datetime(df[col])
except (ValueError, TypeError):
pass
# 金额列处理
for col in df.columns:
if "amount" in col.lower() or "price" in col.lower() or "money" in col.lower():
if df[col].dtype == "object":
# 去除货币符号和逗号
df[col] = df[col].astype(str).str.replace(r'[^\d.-]', '', regex=True)
df[col] = pd.to_numeric(df[col], errors="coerce")
# 手机号格式统一
if "phone" in df.columns:
df["phone"] = df["phone"].astype(str).str.replace(r'\D', '', regex=True)
# 只保留11位手机号
df.loc[df["phone"].str.len() != 11, "phone"] = None
# 邮箱转小写
if "email" in df.columns:
df["email"] = df["email"].astype(str).str.lower().str.strip()
# 字符串去除首尾空格
for col in df.select_dtypes(include=["object"]).columns:
df[col] = df[col].astype(str).str.strip()
# 空字符串转为None
df.loc[df[col] == "", col] = None
return df
def detect_outliers(
self,
df: pd.DataFrame,
columns: List[str] = None,
method: str = "iqr",
threshold: float = 1.5
) -> pd.DataFrame:
"""
检测异常值
常用方法:
- iqr: 四分位距法(IQR = Q3 - Q1,超出 [Q1-1.5*IQR, Q3+1.5*IQR] 为异常)
- zscore: Z分数法(|Z| > 3 为异常)
Args:
df: 输入DataFrame
columns: 要检测的列
method: 检测方法
threshold: 阈值(IQR倍数或Z分数阈值)
Returns:
标记了异常值的DataFrame(添加_is_outlier列)
"""
if columns is None:
columns = df.select_dtypes(include=["number"]).columns.tolist()
df["_is_outlier"] = False
for col in columns:
if not pd.api.types.is_numeric_dtype(df[col]):
continue
if method == "iqr":
Q1 = df[col].quantile(0.25)
Q3 = df[col].quantile(0.75)
IQR = Q3 - Q1
lower_bound = Q1 - threshold * IQR
upper_bound = Q3 + threshold * IQR
col_outliers = (df[col] < lower_bound) | (df[col] > upper_bound)
elif method == "zscore":
mean = df[col].mean()
std = df[col].std()
if std == 0:
continue
z_scores = abs((df[col] - mean) / std)
col_outliers = z_scores > threshold
else:
logger.warning(f"不支持的异常检测方法: {method}")
continue
df["_is_outlier"] = df["_is_outlier"] | col_outliers
outlier_count = col_outliers.sum()
if outlier_count > 0:
logger.info(f"列 {col} 检测到 {outlier_count} 个异常值")
self.cleaning_stats["outliers_removed"] += outlier_count
return df
def clean(self, df: pd.DataFrame) -> pd.DataFrame:
"""
执行完整的数据清洗流程
Args:
df: 输入DataFrame
Returns:
清洗后的DataFrame
"""
logger.info(f"开始数据清洗,输入行数: {len(df)}")
# 1. 标准化格式
df = self.standardize_formats(df)
# 2. 处理缺失值
df = self.handle_missing_values(df, strategy="auto")
# 3. 去除重复
df = self.remove_duplicates(df)
# 4. 检测异常值(不删除,只标记)
df = self.detect_outliers(df)
# 5. 可选:删除异常值
if self.config.get("remove_outliers", False):
before = len(df)
df = df[~df["_is_outlier"]]
removed = before - len(df)
logger.info(f"删除异常值: {removed} 条")
logger.info(f"数据清洗完成,输出行数: {len(df)}")
return df
def get_stats(self) -> Dict:
"""获取清洗统计信息"""
return self.cleaning_stats
3.2 数据关联器(构建宽表)
# src/processor/joiner.py
"""
数据关联模块
将订单、用户、商品数据关联起来,构建分析宽表。
宽表是数据分析的基础,将所有相关字段放在一张表中,避免联表查询。
"""
from typing import Dict, List, Optional
import pandas as pd
from loguru import logger
class DataJoiner:
"""
数据关联器
将多张事实表和维度表关联成宽表。
典型的数据模型:
- 事实表:订单表(fact_orders)
- 维度表:用户表(dim_users)、商品表(dim_products)
- 宽表:订单宽表(包含订单、用户、商品的所有相关字段)
"""
def __init__(self):
self.join_stats = {}
def build_order_wide_table(
self,
orders_df: pd.DataFrame,
users_df: pd.DataFrame,
products_df: pd.DataFrame,
join_keys: Optional[Dict] = None
) -> pd.DataFrame:
"""
构建订单宽表
将订单表与用户表、商品表关联,生成包含所有分析字段的宽表。
Args:
orders_df: 订单事实表
users_df: 用户维度表
products_df: 商品维度表
join_keys: 关联键配置,默认:
- orders.user_id -> users.user_id
- orders.product_id -> products.product_id
Returns:
订单宽表
"""
if join_keys is None:
join_keys = {
"user": {"left": "user_id", "right": "user_id"},
"product": {"left": "product_id", "right": "product_id"}
}
before_count = len(orders_df)
logger.info(f"开始构建订单宽表,订单数: {before_count}")
# 1. 关联用户维度
orders_df = orders_df.merge(
users_df,
left_on=join_keys["user"]["left"],
right_on=join_keys["user"]["right"],
how="left",
suffixes=("", "_user")
)
# 2. 关联商品维度
orders_df = orders_df.merge(
products_df,
left_on=join_keys["product"]["left"],
right_on=join_keys["product"]["right"],
how="left",
suffixes=("", "_product")
)
after_count = len(orders_df)
# 统计关联情况
user_matched = orders_df["user_id_user"].notna().sum() if "user_id_user" in orders_df.columns else 0
product_matched = orders_df["product_id_product"].notna().sum() if "product_id_product" in orders_df.columns else 0
self.join_stats = {
"orders_total": before_count,
"orders_after_join": after_count,
"user_matched": user_matched,
"user_unmatched": before_count - user_matched,
"product_matched": product_matched,
"product_unmatched": before_count - product_matched,
}
logger.info(
f"订单宽表构建完成: "
f"用户匹配率={user_matched/before_count:.2%}, "
f"商品匹配率={product_matched/before_count:.2%}"
)
# 3. 添加派生字段
orders_df = self._add_derived_fields(orders_df)
return orders_df
def _add_derived_fields(self, df: pd.DataFrame) -> pd.DataFrame:
"""
添加派生字段
派生字段是从基础字段计算得出的,用于分析。
"""
# 日期维度
if "create_time" in df.columns:
df["create_date"] = pd.to_datetime(df["create_time"]).dt.date
df["create_year"] = pd.to_datetime(df["create_time"]).dt.year
df["create_month"] = pd.to_datetime(df["create_time"]).dt.month
df["create_day"] = pd.to_datetime(df["create_time"]).dt.day
df["create_hour"] = pd.to_datetime(df["create_time"]).dt.hour
df["create_weekday"] = pd.to_datetime(df["create_time"]).dt.dayofweek
df["create_week"] = pd.to_datetime(df["create_time"]).dt.isocalendar().week
# 金额相关
if "pay_amount" in df.columns:
# 是否全额支付
if "total_amount" in df.columns:
df["is_full_payment"] = df["pay_amount"] >= df["total_amount"]
# 时间差
if "pay_time" in df.columns and "create_time" in df.columns:
df["pay_delay_minutes"] = (
pd.to_datetime(df["pay_time"]) - pd.to_datetime(df["create_time"])
).dt.total_seconds() / 60
# 订单状态分组
if "order_status" in df.columns:
# 将状态码转换为中文
status_map = {
0: "待支付",
1: "已支付",
2: "已取消",
3: "已发货",
4: "已完成",
}
df["order_status_name"] = df["order_status"].map(status_map)
# 是否完成
df["is_completed"] = df["order_status"] == 4
return df
3.3 RFM用户分层模型
# src/processor/rfm_model.py
"""
RFM用户分层模型
RFM模型是衡量用户价值和用户创利能力的重要工具:
- R(Recency):最近一次消费时间,距离今天越近越好
- F(Frequency):消费频率,消费越频繁越好
- M(Monetary):消费金额,消费越多越好
通过RFM评分,可以将用户分为:
- 重要价值用户(高R高F高M)
- 重要发展用户(高R低F高M)
- 重要保持用户(低R高F高M)
- 重要挽留用户(低R低F高M)
- 一般价值用户(高R高F低M)
- 一般发展用户(高R低F低M)
- 一般保持用户(低R高F低M)
- 一般挽留用户(低R低F低M)
"""
from datetime import datetime, timedelta
from typing import Dict, List, Tuple
import pandas as pd
import numpy as np
from loguru import logger
class RFMModel:
"""
RFM用户分层模型
基于用户消费行为进行分层,用于精细化运营。
"""
def __init__(self, reference_date: datetime = None):
"""
初始化RFM模型
Args:
reference_date: 参考日期,默认为当前日期
"""
self.reference_date = reference_date or datetime.now()
self.scores = {}
def calculate_rfm(self, orders_df: pd.DataFrame) -> pd.DataFrame:
"""
计算用户的RFM指标
Args:
orders_df: 订单数据(必须包含user_id、create_time、pay_amount)
Returns:
用户的RFM指标表
"""
# 过滤已支付的订单
if "order_status" in orders_df.columns:
orders_df = orders_df[orders_df["order_status"] == 1].copy()
# 确保日期格式正确
orders_df["create_time"] = pd.to_datetime(orders_df["create_time"])
# 按用户分组计算
rfm_df = orders_df.groupby("user_id").agg({
"create_time": lambda x: (self.reference_date - x.max()).days, # R:最近一次消费间隔天数
"user_id": "count", # F:消费次数
"pay_amount": "sum", # M:消费总金额
}).rename(columns={
"create_time": "recency_days",
"user_id": "frequency",
"pay_amount": "monetary"
})
# 添加平均客单价
rfm_df["avg_order_value"] = rfm_df["monetary"] / rfm_df["frequency"]
logger.info(f"RFM计算完成,用户数: {len(rfm_df)}")
return rfm_df
def calculate_rfm_scores(
self,
rfm_df: pd.DataFrame,
r_bins: List[int] = None,
f_bins: List[int] = None,
m_bins: List[float] = None
) -> pd.DataFrame:
"""
计算RFM评分
将R、F、M分别按五分位进行评分(1-5分)。
Args:
rfm_df: RFM指标表
r_bins: R值分箱边界(越小分越高,因为R越小越好)
f_bins: F值分箱边界
m_bins: M值分箱边界
Returns:
带RFM评分的DataFrame
"""
if r_bins is None:
# R值越小越好,所以分箱边界从小到大
r_bins = [0, 30, 60, 90, 180, float('inf')]
r_labels = [5, 4, 3, 2, 1] # 最近消费的得5分
if f_bins is None:
f_bins = [0, 1, 3, 5, 10, float('inf')]
f_labels = [1, 2, 3, 4, 5] # 消费越频繁得分越高
if m_bins is None:
m_bins = [0, 500, 1000, 2000, 5000, float('inf')]
m_labels = [1, 2, 3, 4, 5] # 消费越多得分越高
# R评分
rfm_df["r_score"] = pd.cut(
rfm_df["recency_days"],
bins=r_bins,
labels=r_labels,
right=False
).astype(int)
# F评分
rfm_df["f_score"] = pd.cut(
rfm_df["frequency"],
bins=f_bins,
labels=f_labels,
right=False
).astype(int)
# M评分
rfm_df["m_score"] = pd.cut(
rfm_df["monetary"],
bins=m_bins,
labels=m_labels,
right=False
).astype(int)
# 综合评分
rfm_df["rfm_score"] = rfm_df["r_score"] * 100 + rfm_df["f_score"] * 10 + rfm_df["m_score"]
logger.info(f"RFM评分完成")
return rfm_df
def classify_users(self, rfm_df: pd.DataFrame) -> pd.DataFrame:
"""
根据RFM评分对用户进行分类
Args:
rfm_df: 带RFM评分的DataFrame
Returns:
带用户分类的DataFrame
"""
def get_user_segment(row):
"""根据R、F、M的评分确定用户类型"""
r, f, m = row["r_score"], row["f_score"], row["m_score"]
if r >= 4 and f >= 4 and m >= 4:
return "重要价值用户"
elif r >= 4 and f >= 4 and m < 4:
return "一般价值用户"
elif r >= 4 and f < 4 and m >= 4:
return "重要发展用户"
elif r >= 4 and f < 4 and m < 4:
return "一般发展用户"
elif r < 4 and f >= 4 and m >= 4:
return "重要保持用户"
elif r < 4 and f >= 4 and m < 4:
return "一般保持用户"
elif r < 4 and f < 4 and m >= 4:
return "重要挽留用户"
else:
return "一般挽留用户"
rfm_df["user_segment"] = rfm_df.apply(get_user_segment, axis=1)
# 统计各类用户数量
segment_stats = rfm_df["user_segment"].value_counts()
for segment, count in segment_stats.items():
logger.info(f"用户类型 {segment}: {count} 人 ({count/len(rfm_df)*100:.1f}%)")
return rfm_df
def run(self, orders_df: pd.DataFrame) -> pd.DataFrame:
"""
执行完整的RFM分析流程
Args:
orders_df: 订单数据
Returns:
RFM分析结果表
"""
logger.info("开始RFM分析")
# 1. 计算RFM指标
rfm_df = self.calculate_rfm(orders_df)
# 2. 计算RFM评分
rfm_df = self.calculate_rfm_scores(rfm_df)
# 3. 用户分类
rfm_df = self.classify_users(rfm_df)
logger.info(f"RFM分析完成,共分析 {len(rfm_df)} 个用户")
return rfm_df