第三部分:召回算法
3.1 协同过滤召回
协同过滤是推荐系统最经典的算法,核心思想是:兴趣相近的用户往往喜欢相同的物品(UserCF);被同一批用户共同喜欢的物品彼此相似(ItemCF)。
// smart-rec-recall/src/main/java/com/smartrec/recall/collaborative/ItemCFRecall.java
package com.smartrec.recall.collaborative;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
/**
 * Item-based collaborative filtering recall (ItemCF).
 *
 * Core idea: if user A likes item X, and item X is similar to item Y,
 * then user A is likely to be interested in item Y as well.
 *
 * Algorithm flow:
 *   1. Compute the item-item similarity matrix (offline).
 *   2. Collect the items the user interacted with, weighted by action type.
 *   3. Expand into similar items, aggregate scores and rank.
 *
 * Item similarity (cosine over co-occurring users):
 *   sim(i, j) = |N(i) ∩ N(j)| / sqrt(|N(i)| * |N(j)|)
 * where N(i) is the set of users who liked item i.
 *
 * Strengths:
 *   - Highly explainable ("recommended because you liked XXX").
 *   - Similarity matrix is computed offline, so online serving is fast.
 *   - Works for new users with only a handful of actions.
 *
 * Weaknesses:
 *   - Cold-start problem for brand-new items.
 *   - Biased towards popular items (mitigated via a popularity penalty).
 */
@Slf4j
@Component
public class ItemCFRecall {

    /** Item similarity matrix (refreshed periodically): itemId -> (similarItemId -> score). */
    private final Map<Long, Map<Long, Double>> itemSimilarityMatrix = new ConcurrentHashMap<>();

    /** Base of the exponential popularity penalty (suppresses over-recommending hot items). */
    private static final double HOT_ITEM_PENALTY = 0.8;

    /** Similarities at or below this threshold are dropped from the matrix. */
    private static final double MIN_SIMILARITY = 0.05;

    /** Per item, keep only this many most-similar neighbours. */
    private static final int MAX_SIMILAR_ITEMS = 200;

    private final RedisTemplate<String, Object> redisTemplate;

    /**
     * Scheduler for the periodic matrix refresh. Kept as a field (instead of a
     * local leaked by the original) and built with daemon threads so it can
     * never prevent JVM shutdown.
     */
    private final ScheduledExecutorService refreshScheduler;

    public ItemCFRecall(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
        this.refreshScheduler = Executors.newSingleThreadScheduledExecutor(runnable -> {
            Thread thread = new Thread(runnable, "itemcf-matrix-refresh");
            thread.setDaemon(true);
            return thread;
        });
        // Load the precomputed matrix once at startup, then refresh on a schedule.
        loadSimilarityMatrix();
        startUpdateTask();
    }

    /**
     * Collaborative-filtering recall.
     *
     * @param userId      user id (used for logging only in this method)
     * @param userActions the user's recent action sequence
     * @param topN        number of items to recall
     * @return recalled items with scores, best first; empty list if the user
     *         has no usable history
     */
    public List<RecallItem> recall(Long userId, List<UserAction> userActions, int topN) {
        // 1. Weight the items the user interacted with (weight depends on action type).
        Map<Long, Double> preferredItems = getUserPreferredItems(userActions);
        if (preferredItems.isEmpty()) {
            log.debug("No preferred items found for user: {}", userId);
            return Collections.emptyList();
        }

        // 2. Expand each preferred item into its similar items and accumulate scores.
        Map<Long, Double> candidateScores = new HashMap<>();
        for (Map.Entry<Long, Double> preferred : preferredItems.entrySet()) {
            Map<Long, Double> similarItems = itemSimilarityMatrix.get(preferred.getKey());
            if (similarItems == null) {
                continue;
            }
            double preferenceWeight = preferred.getValue();
            for (Map.Entry<Long, Double> similar : similarItems.entrySet()) {
                Long candidateId = similar.getKey();
                // Base score = user preference weight * item-item similarity.
                double score = preferenceWeight * similar.getValue();
                // Popularity penalty: exponentially damp hot items for diversity.
                score *= Math.pow(HOT_ITEM_PENALTY, getItemPopularity(candidateId));
                candidateScores.merge(candidateId, score, Double::sum);
            }
        }

        // 3. Never recommend items the user already interacted with.
        candidateScores.keySet().removeAll(getInteractedItems(userActions));

        // 4. Rank by score and return the top N.
        return candidateScores.entrySet().stream()
                .sorted(Map.Entry.<Long, Double>comparingByValue().reversed())
                .limit(topN)
                .map(e -> new RecallItem(e.getKey(), e.getValue(), "itemcf"))
                .collect(Collectors.toList());
    }

    /**
     * Aggregates the user's actions into per-item preference weights.
     *
     * Action-type weights: click=1, favorite=2, add-to-cart=3, buy=5, other=0.5.
     * On top of that an exponential time decay with a 48-hour time constant is
     * applied (about 37% of the weight remains after 48 hours), so recent
     * actions dominate.
     */
    private Map<Long, Double> getUserPreferredItems(List<UserAction> userActions) {
        Map<Long, Double> preferredItems = new HashMap<>();
        long now = System.currentTimeMillis();
        for (UserAction action : userActions) {
            double actionWeight;
            switch (action.getActionType()) {
                case ACTION_CLICK:    actionWeight = 1.0; break;
                case ACTION_FAVORITE: actionWeight = 2.0; break;
                case ACTION_CART:     actionWeight = 3.0; break;
                case ACTION_BUY:      actionWeight = 5.0; break;
                default:              actionWeight = 0.5;
            }
            long hoursAgo = (now - action.getTimestamp()) / (3600 * 1000);
            double timeDecay = Math.exp(-hoursAgo / 48.0);
            preferredItems.merge(action.getItemId(), actionWeight * timeDecay, Double::sum);
        }
        return preferredItems;
    }

    /**
     * Computes the item similarity matrix.
     *
     * In production this runs in an offline Spark job; the in-process version
     * here shows the core algorithm. Instead of the naive all-pairs loop it
     * counts co-occurrences per user, which only ever touches item pairs that
     * actually share at least one user.
     */
    private void computeSimilarityMatrix() {
        // 1. User -> items they interacted with.
        Map<Long, Set<Long>> userItems = getUserItemMatrix();
        // 2. Item -> users who interacted with it (for |N(i)| counts).
        Map<Long, Set<Long>> itemUsers = buildItemUserInvertedIndex(userItems);

        // 3. Co-occurrence counting: for every user, count each pair of items
        //    in their history. Pairs with no common user are never visited.
        Map<Long, Map<Long, Integer>> coOccurrence = new HashMap<>();
        for (Set<Long> items : userItems.values()) {
            List<Long> itemList = new ArrayList<>(items);
            for (int i = 0; i < itemList.size(); i++) {
                for (int j = i + 1; j < itemList.size(); j++) {
                    Long a = itemList.get(i);
                    Long b = itemList.get(j);
                    coOccurrence.computeIfAbsent(a, k -> new HashMap<>()).merge(b, 1, Integer::sum);
                    coOccurrence.computeIfAbsent(b, k -> new HashMap<>()).merge(a, 1, Integer::sum);
                }
            }
        }

        // 4. Normalize to cosine similarity and keep the top-K per item.
        Map<Long, Map<Long, Double>> newMatrix = new HashMap<>();
        for (Map.Entry<Long, Map<Long, Integer>> entry : coOccurrence.entrySet()) {
            Long itemId1 = entry.getKey();
            int n1 = itemUsers.get(itemId1).size();
            Map<Long, Double> similarities = new HashMap<>();
            for (Map.Entry<Long, Integer> co : entry.getValue().entrySet()) {
                int n2 = itemUsers.get(co.getKey()).size();
                // Cast to double BEFORE multiplying: n1 * n2 as int overflows
                // once both items have tens of thousands of users.
                double similarity = co.getValue() / Math.sqrt((double) n1 * n2);
                if (similarity > MIN_SIMILARITY) {
                    similarities.put(co.getKey(), similarity);
                }
            }
            Map<Long, Double> topK = similarities.entrySet().stream()
                    .sorted(Map.Entry.<Long, Double>comparingByValue().reversed())
                    .limit(MAX_SIMILAR_ITEMS)
                    .collect(Collectors.toMap(
                            Map.Entry::getKey, Map.Entry::getValue,
                            (x, y) -> x, LinkedHashMap::new));
            newMatrix.put(itemId1, topK);
        }

        // 5. Swap in: add/replace fresh entries first, then drop stale keys, so
        //    concurrent readers never observe a half-empty matrix.
        itemSimilarityMatrix.putAll(newMatrix);
        itemSimilarityMatrix.keySet().retainAll(newMatrix.keySet());
        log.info("Item similarity matrix computed: {} items", itemSimilarityMatrix.size());
    }

    private void loadSimilarityMatrix() {
        // Loads the precomputed similarity matrix from Redis.
        // NOTE(review): deserialization is still a stub — the loop parses the
        // item ids but never populates itemSimilarityMatrix; the real
        // implementation would parse the JSON payload of similar items here.
        String key = "recall:itemcf:similarity";
        Map<Object, Object> matrixMap = redisTemplate.opsForHash().entries(key);
        for (Map.Entry<Object, Object> entry : matrixMap.entrySet()) {
            Long itemId = Long.valueOf(entry.getKey().toString());
        }
        log.info("Item similarity matrix loaded: {} items", itemSimilarityMatrix.size());
    }

    private void startUpdateTask() {
        // Reload the similarity matrix from Redis every 6 hours.
        refreshScheduler.scheduleAtFixedRate(() -> {
            log.info("Updating item similarity matrix...");
            loadSimilarityMatrix();
        }, 6, 6, TimeUnit.HOURS);
    }

    private double getItemPopularity(Long itemId) {
        // Item popularity normalized to [0, 1]; a missing entry means "not hot".
        String key = "item:popularity:" + itemId;
        Double popularity = (Double) redisTemplate.opsForValue().get(key);
        return popularity != null ? popularity : 0.0;
    }

    private Set<Long> getInteractedItems(List<UserAction> userActions) {
        return userActions.stream()
                .map(UserAction::getItemId)
                .collect(Collectors.toSet());
    }

    private Map<Long, Set<Long>> getUserItemMatrix() {
        // Stub: production code loads the user-item interaction matrix from
        // the data warehouse.
        return new HashMap<>();
    }

    /** Inverts user->items into item->users, used for the |N(i)| terms. */
    private Map<Long, Set<Long>> buildItemUserInvertedIndex(Map<Long, Set<Long>> userItems) {
        Map<Long, Set<Long>> itemUsers = new HashMap<>();
        for (Map.Entry<Long, Set<Long>> entry : userItems.entrySet()) {
            Long userId = entry.getKey();
            for (Long itemId : entry.getValue()) {
                itemUsers.computeIfAbsent(itemId, k -> new HashSet<>()).add(userId);
            }
        }
        return itemUsers;
    }
}
3.2 向量召回(Embedding-based Recall)
向量召回是现代推荐系统的核心技术。通过将用户和物品映射到同一个向量空间,用户向量和物品向量的相似度可以直接反映用户对物品的兴趣程度。
// smart-rec-recall/src/main/java/com/smartrec/recall/vector/VectorRecall.java
package com.smartrec.recall.vector;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.data.redis.core.RedisTemplate;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Embedding-based vector recall.
 *
 * Users and items are mapped into the same dense vector space; the similarity
 * between a user vector and an item vector (cosine / inner product) directly
 * reflects the user's interest in the item.
 *
 * Common embedding models:
 *   1. YouTube DNN — deep network producing user and item vectors end-to-end
 *   2. DSSM — two-tower architecture (user tower + item tower)
 *   3. Graph embeddings — Node2Vec, EGES, leveraging graph structure
 *   4. BPR — Bayesian personalized ranking on matrix factorization
 *
 * Why vector recall:
 *   - Captures deep semantic signals about users and items
 *   - Generalizes well; can surface items the user never saw before
 *   - User vectors can be recomputed in near real time
 *
 * Retrieval backends:
 *   - FAISS (Facebook AI Similarity Search) for large-scale ANN
 *   - HNSW (Hierarchical Navigable Small World) graph-based ANN
 */
@Slf4j
@Component
public class VectorRecall {

    /** Offline-trained item embeddings: itemId -> vector. */
    private final Map<Long, float[]> itemEmbeddings = new ConcurrentHashMap<>();

    /**
     * Cache of aggregated user vectors.
     * NOTE(review): unbounded and never invalidated — a long-lived process will
     * grow this without limit and serve stale vectors; consider a TTL cache.
     */
    private final Map<Long, float[]> userEmbeddingCache = new ConcurrentHashMap<>();

    /** ANN index over the item vectors (FAISS or an in-house HNSW). */
    private VectorIndex vectorIndex;

    private final RedisTemplate<String, Object> redisTemplate;

    public VectorRecall(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
        loadItemEmbeddings();
        buildVectorIndex();
    }

    /**
     * Vector recall.
     *
     * @param userId      user id
     * @param userActions the user's recent actions
     * @param topN        number of items to recall
     * @return recalled items, best first; empty list when no user vector exists
     */
    public List<RecallItem> recall(Long userId, List<UserAction> userActions, int topN) {
        // 1. Obtain the user vector (cached or aggregated from history).
        float[] userEmbedding = getUserEmbedding(userId, userActions);
        if (userEmbedding == null) {
            log.debug("No user embedding for user: {}", userId);
            return Collections.emptyList();
        }

        // 2. ANN search: over-fetch 2x so post-filtering can still fill topN.
        List<Long> candidateItems = vectorIndex.search(userEmbedding, topN * 2);

        // 3. Score candidates and drop already-seen items.
        Set<Long> interactedItems = getInteractedItems(userActions);
        List<RecallItem> results = new ArrayList<>(candidateItems.size());
        for (Long itemId : candidateItems) {
            if (interactedItems.contains(itemId)) {
                continue;
            }
            float[] itemEmbedding = itemEmbeddings.get(itemId);
            if (itemEmbedding == null) {
                continue;
            }
            results.add(new RecallItem(itemId, cosineSimilarity(userEmbedding, itemEmbedding), "vector"));
        }

        // 4. Rank and truncate.
        return results.stream()
                .sorted(Comparator.comparingDouble(RecallItem::getScore).reversed())
                .limit(topN)
                .collect(Collectors.toList());
    }

    /**
     * Obtains the user vector.
     *
     * Two standard strategies exist:
     *   1. End-to-end: the model emits the user vector directly (YouTube DNN).
     *   2. Aggregation: weighted average of the vectors of items the user
     *      interacted with — the strategy implemented here.
     *
     * Weights combine action type (click/favorite/cart/buy) with an
     * exponential time decay (72-hour time constant).
     */
    private float[] getUserEmbedding(Long userId, List<UserAction> userActions) {
        // Single cache lookup instead of containsKey + get.
        float[] cached = userEmbeddingCache.get(userId);
        if (cached != null) {
            return cached;
        }
        // No history at all: fall back to the default vector.
        if (userActions.isEmpty()) {
            return getDefaultUserEmbedding();
        }

        // Weighted sum of the item vectors from the user's history.
        float[] aggregated = null;
        double totalWeight = 0;
        for (UserAction action : userActions) {
            float[] itemEmbed = itemEmbeddings.get(action.getItemId());
            if (itemEmbed == null) {
                continue;
            }
            double weight = getActionWeight(action.getActionType());
            long hoursAgo = (System.currentTimeMillis() - action.getTimestamp()) / (3600 * 1000);
            weight *= Math.exp(-hoursAgo / 72.0);
            if (aggregated == null) {
                aggregated = new float[itemEmbed.length];
            }
            for (int i = 0; i < itemEmbed.length; i++) {
                aggregated[i] += itemEmbed[i] * weight;
            }
            totalWeight += weight;
        }
        // None of the user's items had an embedding.
        if (aggregated == null) {
            return getDefaultUserEmbedding();
        }

        // Normalize by total weight (totalWeight > 0 whenever aggregated != null,
        // since every per-action weight is strictly positive).
        for (int i = 0; i < aggregated.length; i++) {
            aggregated[i] /= totalWeight;
        }
        userEmbeddingCache.put(userId, aggregated);
        return aggregated;
    }

    /**
     * Builds the ANN index over all item vectors.
     *
     * Linear scan is too slow for millions of items, hence approximate
     * nearest-neighbour search (simplified HNSW here, FAISS in production).
     * NOTE(review): dimension is hard-coded to 128 — confirm it matches the
     * actual width of the loaded embeddings.
     */
    private void buildVectorIndex() {
        int dimension = 128;
        int numItems = itemEmbeddings.size();
        float[][] vectors = new float[numItems][dimension];
        long[] itemIds = new long[numItems];
        int idx = 0;
        for (Map.Entry<Long, float[]> entry : itemEmbeddings.entrySet()) {
            itemIds[idx] = entry.getKey();
            vectors[idx] = entry.getValue();
            idx++;
        }
        vectorIndex = new HNSWIndex(dimension, vectors, itemIds);
        log.info("Vector index built: {} items, dimension={}", numItems, dimension);
    }

    /**
     * Loads item embeddings from a CSV-like file (itemId,v0,v1,...).
     * Vectors are produced offline (Word2Vec / Node2Vec / deep models).
     * A single malformed line is skipped instead of aborting the whole load,
     * which is what the original all-or-nothing try block did.
     */
    private void loadItemEmbeddings() {
        String embeddingFile = "models/item_embeddings.txt";
        try {
            for (String line : Files.readAllLines(Paths.get(embeddingFile))) {
                if (line.isEmpty()) {
                    continue;
                }
                try {
                    String[] parts = line.split(",");
                    Long itemId = Long.parseLong(parts[0]);
                    float[] embedding = new float[parts.length - 1];
                    for (int i = 1; i < parts.length; i++) {
                        embedding[i - 1] = Float.parseFloat(parts[i]);
                    }
                    itemEmbeddings.put(itemId, embedding);
                } catch (NumberFormatException e) {
                    log.warn("Skipping malformed embedding line: {}", line);
                }
            }
            log.info("Loaded {} item embeddings", itemEmbeddings.size());
        } catch (Exception e) {
            log.error("Failed to load item embeddings", e);
        }
    }

    /** Fallback user vector (zero vector; a global-average vector also works). */
    private float[] getDefaultUserEmbedding() {
        return new float[128];
    }

    /**
     * Cosine similarity between two equal-length vectors.
     * Returns 0 when either vector has zero norm — without this guard a
     * zero-vector user embedding (the default) would produce NaN scores.
     */
    private double cosineSimilarity(float[] a, float[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        if (normA == 0 || normB == 0) {
            return 0.0;
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /** Action weight: click=1, favorite=2, cart=3, buy=5, other=0.5. */
    private double getActionWeight(int actionType) {
        switch (actionType) {
            case ACTION_CLICK:    return 1.0;
            case ACTION_FAVORITE: return 2.0;
            case ACTION_CART:     return 3.0;
            case ACTION_BUY:      return 5.0;
            default:              return 0.5;
        }
    }

    private Set<Long> getInteractedItems(List<UserAction> actions) {
        return actions.stream().map(UserAction::getItemId).collect(Collectors.toSet());
    }
}
3.3 多路召回合并
// smart-rec-recall/src/main/java/com/smartrec/recall/merge/MultiChannelRecallMerger.java
package com.smartrec.recall.merge;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
/**
* 多路召回合并器
*
* 单一召回算法都有其局限性,多路召回可以取长补短:
* - 协同过滤:利用用户相似性
* - 向量召回:利用语义相似性
* - 热门召回:保证基础流量
* - 品类召回:保证多样性
*
* 合并策略:
* 1. 加权合并:不同召回通道赋予不同权重
* 2. 排序合并:将所有候选统一打分排序
* 3. 分层合并:先按通道召回,再交叉融合
*/
@Slf4j
@Component
public class MultiChannelRecallMerger {
// 各召回通道的权重配置
private final Map<String, Double> channelWeights = new HashMap<>();
// 线程池,并行执行各通道召回
private final ExecutorService executor = Executors.newFixedThreadPool(8);
// 各召回通道实例
private final ItemCFRecall itemCFRecall;
private final VectorRecall vectorRecall;
private final HotRecall hotRecall;
private final CategoryRecall categoryRecall;
public MultiChannelRecallMerger(
ItemCFRecall itemCFRecall,
VectorRecall vectorRecall,
HotRecall hotRecall,
CategoryRecall categoryRecall) {
this.itemCFRecall = itemCFRecall;
this.vectorRecall = vectorRecall;
this.hotRecall = hotRecall;
this.categoryRecall = categoryRecall;
// 配置通道权重
channelWeights.put("itemcf", 1.0);
channelWeights.put("vector", 1.2);
channelWeights.put("hot", 0.5);
channelWeights.put("category", 0.8);
}
/**
* 多路召回合并
*
* 并行执行所有召回通道,然后合并结果。
*
* @param userId 用户ID
* @param userActions 用户行为
* @param topN 最终返回数量
* @return 合并后的推荐列表
*/
public List<RecallItem> merge(Long userId, List<UserAction> userActions, int topN) {
// 并行执行各通道召回
CompletableFuture<List<RecallItem>> itemCFFuture =
CompletableFuture.supplyAsync(() -> itemCFRecall.recall(userId, userActions, topN * 2), executor);
CompletableFuture<List<RecallItem>> vectorFuture =
CompletableFuture.supplyAsync(() -> vectorRecall.recall(userId, userActions, topN * 2), executor);
CompletableFuture<List<RecallItem>> hotFuture =
CompletableFuture.supplyAsync(() -> hotRecall.recall(userId, userActions, topN), executor);
CompletableFuture<List<RecallItem>> categoryFuture =
CompletableFuture.supplyAsync(() -> categoryRecall.recall(userId, userActions, topN), executor);
// 等待所有召回完成
CompletableFuture.allOf(itemCFFuture, vectorFuture, hotFuture, categoryFuture).join();
try {
List<RecallItem> itemCFResults = itemCFFuture.get();
List<RecallItem> vectorResults = vectorFuture.get();
List<RecallItem> hotResults = hotFuture.get();
List<RecallItem> categoryResults = categoryFuture.get();
// 合并结果
Map<Long, RecallItem> mergedMap = new HashMap<>();
mergeChannel(itemCFResults, mergedMap, "itemcf");
mergeChannel(vectorResults, mergedMap, "vector");
mergeChannel(hotResults, mergedMap, "hot");
mergeChannel(categoryResults, mergedMap, "category");
// 排序并返回
return mergedMap.values().stream()
.sorted(Comparator.comparingDouble(RecallItem::getScore).reversed())
.limit(topN)
.collect(Collectors.toList());
} catch (Exception e) {
log.error("Multi-channel recall failed", e);
return Collections.emptyList();
}
}
/**
* 合并单个通道的结果
*
* 使用加权分数合并:
* final_score = channel_weight * recall_score
*/
private void mergeChannel(List<RecallItem> channelResults,
Map<Long, RecallItem> mergedMap,
String channelName) {
double channelWeight = channelWeights.getOrDefault(channelName, 1.0);
for (RecallItem item : channelResults) {
double weightedScore = item.getScore() * channelWeight;
if (mergedMap.containsKey(item.getItemId())) {
// 去重:取最高分
RecallItem existing = mergedMap.get(item.getItemId());
if (weightedScore > existing.getScore()) {
existing.setScore(weightedScore);
existing.addChannel(channelName);
}
} else {
RecallItem newItem = new RecallItem(item.getItemId(), weightedScore, channelName);
mergedMap.put(item.getItemId(), newItem);
}
}
}
}
第四部分:排序模型
4.1 DeepFM排序模型
排序模型是推荐系统的核心,它决定了最终展示给用户的物品顺序。DeepFM结合了因子分解机(FM)和深度神经网络(DNN)的优势,能够自动学习低阶和高阶特征交叉。
// smart-rec-ranking/src/main/java/com/smartrec/ranking/model/DeepFMModel.java
package com.smartrec.ranking.model;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import java.io.File;
import java.util.concurrent.atomic.AtomicLong;
/**
 * DeepFM ranking model.
 *
 * DeepFM has two components:
 *   1. FM part — learns first-order features and second-order feature crosses
 *   2. Deep part — learns high-order feature crosses via a DNN
 *
 * Why DeepFM:
 *   - End-to-end training, no manual feature-cross engineering
 *   - Captures both low-order and high-order feature interactions
 *   - Fast to train, suitable for large-scale data
 *
 * Intended model structure:
 * ┌─────────────────────────────────────────────────┐
 * │               Output layer (Sigmoid)            │
 * │                  predicted CTR                  │
 * └─────────────────────────────────────────────────┘
 *                          ▲
 *          ┌───────────────┴───────────────┐
 *          │                               │
 *   ┌──────┴──────┐                 ┌──────┴──────┐
 *   │   FM part   │                 │  Deep part  │
 *   │ 1st + 2nd   │                 │ multi-layer │
 *   │   order     │                 │     DNN     │
 *   └─────────────┘                 └─────────────┘
 *          ▲                               ▲
 *          └───────────────┬───────────────┘
 *                          │
 *                   ┌──────┴──────┐
 *                   │ Input layer │
 *                   │sparse feats │
 *                   └─────────────┘
 *
 * NOTE(review): the class body below is largely illustrative/stub code — see
 * the per-method notes. It also uses List/ArrayList in train() without a
 * java.util import in this file's import block — confirm before compiling.
 */
@Slf4j
@Component
public class DeepFMModel {
    // The DL4J network; remains null until a model is created or loaded
    // (loadOrCreateModel() currently leaves it null — see below).
    private MultiLayerNetwork model;
    // Total number of predict() calls served, for monitoring.
    private final AtomicLong predictionCounter = new AtomicLong(0);
    // Model hyper-parameters.
    private static final int EMBEDDING_SIZE = 16;
    private static final int DEEP_LAYER_1 = 256;
    private static final int DEEP_LAYER_2 = 128;
    private static final int DEEP_LAYER_3 = 64;
    private static final double LEARNING_RATE = 0.001;
    // NOTE(review): BATCH_SIZE and EPOCHS are declared but unused in the
    // visible code — presumably for the omitted training loop.
    private static final int BATCH_SIZE = 512;
    private static final int EPOCHS = 10;
    @PostConstruct
    public void init() {
        loadOrCreateModel();
    }
    /**
     * Builds the DeepFM network configuration.
     *
     * Input features:
     *   - sparse: user id, item id, category id, brand id, ... (embedded)
     *   - dense:  user age, item price, statistical features, ...
     *
     * NOTE(review): a .list() MultiLayerConfiguration is strictly SEQUENTIAL,
     * so the "parallel" FM and Deep branches drawn in the class javadoc cannot
     * actually be expressed here: layer 1's nIn (sparseEmbeddingDim) does not
     * match layer 0's nOut (1), layer 2's nIn (inputDim) does not match layer
     * 1's nOut (1), and layer 5's nIn (DEEP_LAYER_3 + 2) does not match layer
     * 4's nOut (DEEP_LAYER_3). A faithful DeepFM needs a ComputationGraph
     * with explicit branch/merge vertices — confirm before training this.
     */
    private MultiLayerNetwork buildModel(int numSparseFeatures, int numDenseFeatures) {
        // Total embedding width for sparse features = feature count × embedding size.
        int sparseEmbeddingDim = numSparseFeatures * EMBEDDING_SIZE;
        int inputDim = sparseEmbeddingDim + numDenseFeatures;
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(42)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(new Adam(LEARNING_RATE))
                .weightInit(WeightInit.XAVIER)
                .list()
                // FM part: first-order term (identity activation, scalar output).
                .layer(0, new DenseLayer.Builder()
                        .nIn(inputDim)
                        .nOut(1)
                        .activation(Activation.IDENTITY)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                // FM part: second-order feature crosses (normally an embedding
                // inner product; simplified here into a single dense layer).
                .layer(1, new DenseLayer.Builder()
                        .nIn(sparseEmbeddingDim)
                        .nOut(1)
                        .activation(Activation.IDENTITY)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                // Deep part: multi-layer perceptron with dropout.
                .layer(2, new DenseLayer.Builder()
                        .nIn(inputDim)
                        .nOut(DEEP_LAYER_1)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .dropOut(0.2)
                        .build())
                .layer(3, new DenseLayer.Builder()
                        .nIn(DEEP_LAYER_1)
                        .nOut(DEEP_LAYER_2)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .dropOut(0.2)
                        .build())
                .layer(4, new DenseLayer.Builder()
                        .nIn(DEEP_LAYER_2)
                        .nOut(DEEP_LAYER_3)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .dropOut(0.2)
                        .build())
                // Output layer: intended to combine FM and Deep outputs
                // (Deep output + FM first-order + FM second-order).
                .layer(5, new OutputLayer.Builder()
                        .nIn(DEEP_LAYER_3 + 2)
                        .nOut(1)
                        .activation(Activation.SIGMOID)
                        .lossFunction(LossFunctions.LossFunction.XENT)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .backpropType(BackpropType.Standard)
                .build();
        return new MultiLayerNetwork(conf);
    }
    /**
     * Trains the model.
     *
     * Training data comes from user behavior logs: positive samples are
     * clicks/purchases, negative samples are impressions without a click.
     * NOTE(review): the actual NDArray conversion and fit() call are omitted —
     * this method only extracts feature vectors and logs.
     */
    public void train(List<TrainingSample> samples) {
        log.info("Starting model training with {} samples", samples.size());
        // 1. Feature preprocessing.
        List<float[]> features = new ArrayList<>();
        List<Float> labels = new ArrayList<>();
        for (TrainingSample sample : samples) {
            float[] featureVector = extractFeatureVector(sample);
            features.add(featureVector);
            labels.add(sample.getLabel() ? 1.0f : 0.0f);
        }
        // 2. Conversion to NDArray format.
        // Actual training code omitted...
        log.info("Model training completed");
    }
    /**
     * Online single-sample prediction.
     *
     * NOTE(review): inference is stubbed out — the commented lines show the
     * intended DL4J call; the method currently always returns 0.05f.
     *
     * @param featureVector preprocessed feature vector
     * @return predicted click-through rate in [0, 1]
     */
    public float predict(float[] featureVector) {
        long startTime = System.nanoTime();
        // Model inference (intended implementation):
        // INDArray input = Nd4j.create(featureVector);
        // INDArray output = model.output(input);
        // float ctr = output.getFloat(0);
        // Simplified stub: constant score.
        float ctr = 0.05f;
        long duration = (System.nanoTime() - startTime) / 1000; // microseconds
        predictionCounter.incrementAndGet();
        // Slow-prediction monitoring (> 1 ms).
        if (duration > 1000) {
            log.warn("Slow prediction: {} μs", duration);
        }
        return ctr;
    }
    /**
     * Batch prediction.
     *
     * NOTE(review): despite the comment about vectorized efficiency, this
     * currently just loops over predict(); true batching would build one
     * NDArray for all rows.
     */
    public float[] predictBatch(List<float[]> featureVectors) {
        float[] results = new float[featureVectors.size()];
        for (int i = 0; i < featureVectors.size(); i++) {
            results[i] = predict(featureVectors[i]);
        }
        return results;
    }
    /**
     * Converts a raw training sample into a numeric feature vector
     * (sparse-feature embeddings + normalized dense features).
     * Stub: returns an all-zero 128-dim vector.
     */
    private float[] extractFeatureVector(TrainingSample sample) {
        float[] vector = new float[128];
        // Implementation omitted.
        return vector;
    }
    /**
     * Loads the persisted model if present, otherwise creates a fresh one.
     * NOTE(review): both branches only log — `model` is never assigned, so
     * any real inference path would NPE until this is implemented.
     */
    private void loadOrCreateModel() {
        File modelFile = new File("models/deepfm.zip");
        if (modelFile.exists()) {
            // Load the existing model.
            log.info("Loading existing model from {}", modelFile.getAbsolutePath());
        } else {
            // Create a new model.
            log.info("Creating new model");
            // model = buildModel(numSparseFeatures, numDenseFeatures);
            // model.init();
        }
    }
    /** Persists the model to disk (stubbed — only logs). */
    public void saveModel(String path) {
        log.info("Saving model to {}", path);
        // model.save(new File(path));
    }
    /** @return total number of predictions served since startup */
    public long getPredictionCount() {
        return predictionCounter.get();
    }
}