Java+AI实战:从零构建智能推荐系统(二)

简介: 教程来源 https://tmywi.cn/category/jiaju.html 本节详解推荐系统核心模块:第三部分“召回算法”涵盖协同过滤(ItemCF)、向量召回(Embedding+ANN)及多路融合策略;第四部分“排序模型”介绍DeepFM——融合FM低阶交叉与DNN高阶特征的CTR预估模型,兼顾可解释性与表达能力。

第三部分:召回算法

3.1 协同过滤召回
协同过滤是推荐系统最经典的算法,核心思想是:相似的用户喜欢相似的物品,相似的物品被相似的用户喜欢。

// smart-rec-recall/src/main/java/com/smartrec/recall/collaborative/ItemCFRecall.java
package com.smartrec.recall.collaborative;

import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.data.redis.core.RedisTemplate;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * 基于物品的协同过滤召回(ItemCF)
 * 
 * 核心思想:如果用户A喜欢物品X,而物品X和物品Y相似,那么用户A也可能喜欢物品Y。
 * 
 * 算法流程:
 * 1. 计算物品相似度矩阵(离线计算)
 * 2. 根据用户历史行为,找到用户喜欢的物品
 * 3. 找到这些物品的相似物品,聚合后排序
 * 
 * 物品相似度计算公式:
 *   sim(i, j) = |N(i) ∩ N(j)| / sqrt(|N(i)| * |N(j)|)
 *   
 *   其中 N(i) 是喜欢物品i的用户集合
 * 
 * 优点:
 * - 可解释性强("因为您喜欢XXX,所以推荐YYY")
 * - 物品相似度可以离线计算,在线速度快
 * - 对新用户友好(只要有少量行为就能召回)
 * 
 * 缺点:
 * - 对新物品不友好(冷启动问题)
 * - 倾向于推荐热门物品(需要做降权处理)
 */
@Slf4j
@Component
public class ItemCFRecall {

    // 物品相似度矩阵(离线计算,定期更新)
    // Map<itemId, Map<similarItemId, similarityScore>>
    private final Map<Long, Map<Long, Double>> itemSimilarityMatrix = new ConcurrentHashMap<>();

    // 热门物品惩罚系数(抑制热门物品过度推荐)
    private static final double HOT_ITEM_PENALTY = 0.8;

    private final RedisTemplate<String, Object> redisTemplate;

    public ItemCFRecall(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
        // 启动时加载相似度矩阵
        loadSimilarityMatrix();
        // 启动定时更新任务
        startUpdateTask();
    }

    /**
     * 协同过滤召回
     * 
     * @param userId 用户ID
     * @param userActions 用户最近的行为序列
     * @param topN 召回数量
     * @return 召回的商品列表(带分数)
     */
    public List<RecallItem> recall(Long userId, List<UserAction> userActions, int topN) {
        // 1. 获取用户喜欢的物品(加权,不同行为类型权重不同)
        Map<Long, Double> userPreferredItems = getUserPreferredItems(userActions);

        if (userPreferredItems.isEmpty()) {
            log.debug("No preferred items found for user: {}", userId);
            return Collections.emptyList();
        }

        // 2. 基于物品相似度进行召回
        Map<Long, Double> candidateItems = new HashMap<>();

        for (Map.Entry<Long, Double> entry : userPreferredItems.entrySet()) {
            Long likedItemId = entry.getKey();
            Double weight = entry.getValue();  // 用户对该物品的偏好权重

            // 获取该物品的相似物品列表
            Map<Long, Double> similarItems = itemSimilarityMatrix.get(likedItemId);
            if (similarItems == null) continue;

            for (Map.Entry<Long, Double> simEntry : similarItems.entrySet()) {
                Long candidateItemId = simEntry.getKey();
                Double similarity = simEntry.getValue();

                // 计算候选物品分数 = 用户偏好权重 * 物品相似度
                double score = weight * similarity;

                // 热门物品惩罚:降低热门物品的分数,增加多样性
                double hotScore = getItemPopularity(candidateItemId);
                score = score * Math.pow(HOT_ITEM_PENALTY, hotScore);

                candidateItems.merge(candidateItemId, score, Double::sum);
            }
        }

        // 3. 过滤用户已经交互过的物品
        Set<Long> interactedItems = getInteractedItems(userActions);
        candidateItems.entrySet().removeIf(entry -> interactedItems.contains(entry.getKey()));

        // 4. 排序并返回TopN
        return candidateItems.entrySet().stream()
            .sorted(Map.Entry.<Long, Double>comparingByValue().reversed())
            .limit(topN)
            .map(entry -> new RecallItem(entry.getKey(), entry.getValue(), "itemcf"))
            .collect(Collectors.toList());
    }

    /**
     * 获取用户偏好的物品(加权)
     * 
     * 不同行为类型的权重:
     * - 点击:1
     * - 收藏:2
     * - 加购:3
     * - 购买:5
     * 
     * 时间衰减:越近的行为权重越高
     */
    private Map<Long, Double> getUserPreferredItems(List<UserAction> userActions) {
        Map<Long, Double> preferredItems = new HashMap<>();
        long now = System.currentTimeMillis();

        for (UserAction action : userActions) {
            Long itemId = action.getItemId();
            int actionType = action.getActionType();

            // 行为权重
            double actionWeight;
            switch (actionType) {
                case ACTION_CLICK: actionWeight = 1.0; break;
                case ACTION_FAVORITE: actionWeight = 2.0; break;
                case ACTION_CART: actionWeight = 3.0; break;
                case ACTION_BUY: actionWeight = 5.0; break;
                default: actionWeight = 0.5;
            }

            // 时间衰减:24小时内的行为权重最高,随时间指数衰减
            long hoursAgo = (now - action.getTimestamp()) / (3600 * 1000);
            double timeDecay = Math.exp(-hoursAgo / 48.0);  // 半衰期48小时

            double finalWeight = actionWeight * timeDecay;
            preferredItems.merge(itemId, finalWeight, Double::sum);
        }

        return preferredItems;
    }

    /**
     * 计算物品相似度矩阵
     * 
     * 实际生产中,这个计算通常在离线Spark任务中完成,
     * 这里展示核心算法逻辑。
     */
    private void computeSimilarityMatrix() {
        // 1. 获取用户-物品交互矩阵
        // Map<userId, Set<itemId>>
        Map<Long, Set<Long>> userItems = getUserItemMatrix();

        // 2. 构建物品-用户倒排索引
        // Map<itemId, Set<userId>>
        Map<Long, Set<Long>> itemUsers = buildItemUserInvertedIndex(userItems);

        // 3. 计算物品相似度
        for (Long itemId1 : itemUsers.keySet()) {
            Set<Long> users1 = itemUsers.get(itemId1);
            Map<Long, Double> similarities = new HashMap<>();

            for (Long itemId2 : itemUsers.keySet()) {
                if (itemId1.equals(itemId2)) continue;

                Set<Long> users2 = itemUsers.get(itemId2);

                // 计算共同用户数量
                Set<Long> intersection = new HashSet<>(users1);
                intersection.retainAll(users2);
                int commonUsers = intersection.size();

                if (commonUsers == 0) continue;

                // 余弦相似度
                double similarity = commonUsers / Math.sqrt(users1.size() * users2.size());

                // 只保留相似度大于阈值的物品
                if (similarity > 0.05) {
                    similarities.put(itemId2, similarity);
                }
            }

            // 只保留TopK个最相似的物品
            List<Map.Entry<Long, Double>> topSimilar = similarities.entrySet().stream()
                .sorted(Map.Entry.<Long, Double>comparingByValue().reversed())
                .limit(200)
                .collect(Collectors.toList());

            Map<Long, Double> topSimilarMap = new HashMap<>();
            for (Map.Entry<Long, Double> entry : topSimilar) {
                topSimilarMap.put(entry.getKey(), entry.getValue());
            }
            itemSimilarityMatrix.put(itemId1, topSimilarMap);
        }

        log.info("Item similarity matrix computed: {} items", itemSimilarityMatrix.size());
    }

    private void loadSimilarityMatrix() {
        // 从Redis加载预计算的相似度矩阵
        String key = "recall:itemcf:similarity";
        Map<Object, Object> matrixMap = redisTemplate.opsForHash().entries(key);

        for (Map.Entry<Object, Object> entry : matrixMap.entrySet()) {
            Long itemId = Long.valueOf(entry.getKey().toString());
            // 反序列化相似物品列表
            // 实际实现中会存储JSON格式
        }

        log.info("Item similarity matrix loaded: {} items", itemSimilarityMatrix.size());
    }

    private void startUpdateTask() {
        // 每6小时更新一次相似度矩阵
        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
        scheduler.scheduleAtFixedRate(() -> {
            log.info("Updating item similarity matrix...");
            loadSimilarityMatrix();
        }, 6, 6, TimeUnit.HOURS);
    }

    private double getItemPopularity(Long itemId) {
        // 获取物品的热门程度(归一化后0-1)
        String key = "item:popularity:" + itemId;
        Double popularity = (Double) redisTemplate.opsForValue().get(key);
        return popularity != null ? popularity : 0.0;
    }

    private Set<Long> getInteractedItems(List<UserAction> userActions) {
        return userActions.stream()
            .map(UserAction::getItemId)
            .collect(Collectors.toSet());
    }

    private Map<Long, Set<Long>> getUserItemMatrix() {
        // 从数据仓库加载用户-物品交互矩阵
        return new HashMap<>();
    }

    private Map<Long, Set<Long>> buildItemUserInvertedIndex(Map<Long, Set<Long>> userItems) {
        Map<Long, Set<Long>> itemUsers = new HashMap<>();
        for (Map.Entry<Long, Set<Long>> entry : userItems.entrySet()) {
            Long userId = entry.getKey();
            for (Long itemId : entry.getValue()) {
                itemUsers.computeIfAbsent(itemId, k -> new HashSet<>()).add(userId);
            }
        }
        return itemUsers;
    }
}

3.2 向量召回(Embedding-based Recall)
向量召回是现代推荐系统的核心技术。通过将用户和物品映射到同一个向量空间,用户向量和物品向量的相似度可以直接反映用户对物品的兴趣程度。

// smart-rec-recall/src/main/java/com/smartrec/recall/vector/VectorRecall.java
package com.smartrec.recall.vector;

import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.data.redis.core.RedisTemplate;

import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * 向量召回
 * 
 * 向量召回的核心思想是将用户和物品都表示为稠密向量(Embedding),
 * 然后通过向量相似度(如余弦相似度、内积)来检索最相似的物品。
 * 
 * 常用模型:
 * 1. YouTube DNN:使用深度神经网络学习用户和物品向量
 * 2. DSSM(深度语义匹配模型):双塔结构,用户塔和物品塔
 * 3. Graph Embedding:如Node2Vec、EGES,利用图结构学习向量
 * 4. BPR(贝叶斯个性化排序):基于矩阵分解
 * 
 * 向量召回的优势:
 * - 可以捕获用户和物品的深层语义信息
 * - 泛化能力强,能发现用户可能喜欢的新物品
 * - 支持实时更新(用户向量可以实时计算)
 * 
 * 向量检索技术:
 * - FAISS(Facebook AI Similarity Search):高效的大规模向量检索库
 * - HNSW(Hierarchical Navigable Small World):基于图的近似最近邻搜索
 */
@Slf4j
@Component
public class VectorRecall {

    // 物品向量库(离线训练得到)
    // Map<itemId, float[] embedding>
    private final Map<Long, float[]> itemEmbeddings = new ConcurrentHashMap<>();

    // 用户向量缓存(可实时计算)
    private final Map<Long, float[]> userEmbeddingCache = new ConcurrentHashMap<>();

    // 向量检索索引(使用FAISS或自实现的HNSW)
    private VectorIndex vectorIndex;

    private final RedisTemplate<String, Object> redisTemplate;

    public VectorRecall(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
        loadItemEmbeddings();
        buildVectorIndex();
    }

    /**
     * 向量召回
     * 
     * @param userId 用户ID
     * @param userActions 用户最近行为
     * @param topN 召回数量
     * @return 召回的商品列表
     */
    public List<RecallItem> recall(Long userId, List<UserAction> userActions, int topN) {
        // 1. 获取用户向量
        float[] userEmbedding = getUserEmbedding(userId, userActions);

        if (userEmbedding == null) {
            log.debug("No user embedding for user: {}", userId);
            return Collections.emptyList();
        }

        // 2. 向量检索:找到与用户向量最相似的TopK个物品向量
        List<Long> candidateItems = vectorIndex.search(userEmbedding, topN * 2);

        // 3. 计算相似度分数并过滤
        List<RecallItem> results = new ArrayList<>();
        Set<Long> interactedItems = getInteractedItems(userActions);

        for (Long itemId : candidateItems) {
            if (interactedItems.contains(itemId)) continue;

            float[] itemEmbedding = itemEmbeddings.get(itemId);
            if (itemEmbedding == null) continue;

            double score = cosineSimilarity(userEmbedding, itemEmbedding);
            results.add(new RecallItem(itemId, score, "vector"));
        }

        // 4. 排序返回
        return results.stream()
            .sorted(Comparator.comparingDouble(RecallItem::getScore).reversed())
            .limit(topN)
            .collect(Collectors.toList());
    }

    /**
     * 获取用户向量
     * 
     * 用户向量可以通过两种方式获得:
     * 1. 端到端学习:模型直接输出用户向量(如YouTube DNN)
     * 2. 行为聚合:根据用户历史行为的物品向量聚合得到
     * 
     * 这里展示第二种方法:用户向量 = 用户历史物品向量的加权平均
     */
    private float[] getUserEmbedding(Long userId, List<UserAction> userActions) {
        // 先从缓存获取
        if (userEmbeddingCache.containsKey(userId)) {
            return userEmbeddingCache.get(userId);
        }

        // 如果没有行为,使用默认向量
        if (userActions.isEmpty()) {
            return getDefaultUserEmbedding();
        }

        // 聚合用户历史物品向量
        float[] aggregated = null;
        double totalWeight = 0;

        for (UserAction action : userActions) {
            float[] itemEmbed = itemEmbeddings.get(action.getItemId());
            if (itemEmbed == null) continue;

            // 行为权重
            double weight = getActionWeight(action.getActionType());
            // 时间衰减
            long hoursAgo = (System.currentTimeMillis() - action.getTimestamp()) / (3600 * 1000);
            weight *= Math.exp(-hoursAgo / 72.0);  // 半衰期72小时

            if (aggregated == null) {
                aggregated = new float[itemEmbed.length];
            }

            for (int i = 0; i < itemEmbed.length; i++) {
                aggregated[i] += itemEmbed[i] * weight;
            }
            totalWeight += weight;
        }

        if (aggregated == null) {
            return getDefaultUserEmbedding();
        }

        // 归一化
        for (int i = 0; i < aggregated.length; i++) {
            aggregated[i] /= totalWeight;
        }

        // 缓存
        userEmbeddingCache.put(userId, aggregated);

        return aggregated;
    }

    /**
     * 构建向量索引
     * 
     * 对于百万级物品,线性搜索太慢,需要使用近似最近邻(ANN)算法。
     * 这里展示使用FAISS构建索引的简化版本。
     */
    private void buildVectorIndex() {
        int dimension = 128;  // 向量维度
        int numItems = itemEmbeddings.size();

        // 构建向量矩阵
        float[][] vectors = new float[numItems][dimension];
        long[] itemIds = new long[numItems];

        int idx = 0;
        for (Map.Entry<Long, float[]> entry : itemEmbeddings.entrySet()) {
            itemIds[idx] = entry.getKey();
            vectors[idx] = entry.getValue();
            idx++;
        }

        // 创建索引
        vectorIndex = new HNSWIndex(dimension, vectors, itemIds);
        log.info("Vector index built: {} items, dimension={}", numItems, dimension);
    }

    private void loadItemEmbeddings() {
        // 从文件或Redis加载物品向量
        // 向量通常由离线训练得到,例如使用Word2Vec、Node2Vec或深度学习模型

        String embeddingFile = "models/item_embeddings.txt";
        try {
            List<String> lines = Files.readAllLines(Paths.get(embeddingFile));
            for (String line : lines) {
                String[] parts = line.split(",");
                Long itemId = Long.parseLong(parts[0]);
                float[] embedding = new float[parts.length - 1];
                for (int i = 1; i < parts.length; i++) {
                    embedding[i - 1] = Float.parseFloat(parts[i]);
                }
                itemEmbeddings.put(itemId, embedding);
            }
            log.info("Loaded {} item embeddings", itemEmbeddings.size());
        } catch (Exception e) {
            log.error("Failed to load item embeddings", e);
        }
    }

    private float[] getDefaultUserEmbedding() {
        // 返回零向量或全局平均向量
        return new float[128];
    }

    private double cosineSimilarity(float[] a, float[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    private double getActionWeight(int actionType) {
        switch (actionType) {
            case ACTION_CLICK: return 1.0;
            case ACTION_FAVORITE: return 2.0;
            case ACTION_CART: return 3.0;
            case ACTION_BUY: return 5.0;
            default: return 0.5;
        }
    }

    private Set<Long> getInteractedItems(List<UserAction> actions) {
        return actions.stream().map(UserAction::getItemId).collect(Collectors.toSet());
    }
}

3.3 多路召回合并

// smart-rec-recall/src/main/java/com/smartrec/recall/merge/MultiChannelRecallMerger.java
package com.smartrec.recall.merge;

import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;

/**
 * 多路召回合并器
 * 
 * 单一召回算法都有其局限性,多路召回可以取长补短:
 * - 协同过滤:利用用户相似性
 * - 向量召回:利用语义相似性
 * - 热门召回:保证基础流量
 * - 品类召回:保证多样性
 * 
 * 合并策略:
 * 1. 加权合并:不同召回通道赋予不同权重
 * 2. 排序合并:将所有候选统一打分排序
 * 3. 分层合并:先按通道召回,再交叉融合
 */
@Slf4j
@Component
public class MultiChannelRecallMerger {

    // 各召回通道的权重配置
    private final Map<String, Double> channelWeights = new HashMap<>();

    // 线程池,并行执行各通道召回
    private final ExecutorService executor = Executors.newFixedThreadPool(8);

    // 各召回通道实例
    private final ItemCFRecall itemCFRecall;
    private final VectorRecall vectorRecall;
    private final HotRecall hotRecall;
    private final CategoryRecall categoryRecall;

    public MultiChannelRecallMerger(
            ItemCFRecall itemCFRecall,
            VectorRecall vectorRecall,
            HotRecall hotRecall,
            CategoryRecall categoryRecall) {
        this.itemCFRecall = itemCFRecall;
        this.vectorRecall = vectorRecall;
        this.hotRecall = hotRecall;
        this.categoryRecall = categoryRecall;

        // 配置通道权重
        channelWeights.put("itemcf", 1.0);
        channelWeights.put("vector", 1.2);
        channelWeights.put("hot", 0.5);
        channelWeights.put("category", 0.8);
    }

    /**
     * 多路召回合并
     * 
     * 并行执行所有召回通道,然后合并结果。
     * 
     * @param userId 用户ID
     * @param userActions 用户行为
     * @param topN 最终返回数量
     * @return 合并后的推荐列表
     */
    public List<RecallItem> merge(Long userId, List<UserAction> userActions, int topN) {
        // 并行执行各通道召回
        CompletableFuture<List<RecallItem>> itemCFFuture = 
            CompletableFuture.supplyAsync(() -> itemCFRecall.recall(userId, userActions, topN * 2), executor);

        CompletableFuture<List<RecallItem>> vectorFuture = 
            CompletableFuture.supplyAsync(() -> vectorRecall.recall(userId, userActions, topN * 2), executor);

        CompletableFuture<List<RecallItem>> hotFuture = 
            CompletableFuture.supplyAsync(() -> hotRecall.recall(userId, userActions, topN), executor);

        CompletableFuture<List<RecallItem>> categoryFuture = 
            CompletableFuture.supplyAsync(() -> categoryRecall.recall(userId, userActions, topN), executor);

        // 等待所有召回完成
        CompletableFuture.allOf(itemCFFuture, vectorFuture, hotFuture, categoryFuture).join();

        try {
            List<RecallItem> itemCFResults = itemCFFuture.get();
            List<RecallItem> vectorResults = vectorFuture.get();
            List<RecallItem> hotResults = hotFuture.get();
            List<RecallItem> categoryResults = categoryFuture.get();

            // 合并结果
            Map<Long, RecallItem> mergedMap = new HashMap<>();

            mergeChannel(itemCFResults, mergedMap, "itemcf");
            mergeChannel(vectorResults, mergedMap, "vector");
            mergeChannel(hotResults, mergedMap, "hot");
            mergeChannel(categoryResults, mergedMap, "category");

            // 排序并返回
            return mergedMap.values().stream()
                .sorted(Comparator.comparingDouble(RecallItem::getScore).reversed())
                .limit(topN)
                .collect(Collectors.toList());

        } catch (Exception e) {
            log.error("Multi-channel recall failed", e);
            return Collections.emptyList();
        }
    }

    /**
     * 合并单个通道的结果
     * 
     * 使用加权分数合并:
     *   final_score = channel_weight * recall_score
     */
    private void mergeChannel(List<RecallItem> channelResults, 
                               Map<Long, RecallItem> mergedMap, 
                               String channelName) {
        double channelWeight = channelWeights.getOrDefault(channelName, 1.0);

        for (RecallItem item : channelResults) {
            double weightedScore = item.getScore() * channelWeight;

            if (mergedMap.containsKey(item.getItemId())) {
                // 去重:取最高分
                RecallItem existing = mergedMap.get(item.getItemId());
                if (weightedScore > existing.getScore()) {
                    existing.setScore(weightedScore);
                    existing.addChannel(channelName);
                }
            } else {
                RecallItem newItem = new RecallItem(item.getItemId(), weightedScore, channelName);
                mergedMap.put(item.getItemId(), newItem);
            }
        }
    }
}

第四部分:排序模型

4.1 DeepFM排序模型
排序模型是推荐系统的核心,它决定了最终展示给用户的物品顺序。DeepFM结合了因子分解机(FM)和深度神经网络(DNN)的优势,能够自动学习低阶和高阶特征交叉。

// smart-rec-ranking/src/main/java/com/smartrec/ranking/model/DeepFMModel.java
package com.smartrec.ranking.model;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import java.io.File;
import java.util.concurrent.atomic.AtomicLong;

/**
 * DeepFM排序模型
 * 
 * DeepFM由两部分组成:
 * 1. FM部分:学习一阶特征和二阶特征交叉
 * 2. Deep部分:学习高阶特征交叉
 * 
 * 为什么选择DeepFM?
 * - 端到端训练,无需人工特征工程
 * - 同时捕获低阶和高阶特征交互
 * - 训练速度快,适合大规模数据
 * 
 * 模型结构:
 *   ┌─────────────────────────────────────────────────┐
 *   │                   输出层(Sigmoid)              │
 *   │                    CTR预测值                     │
 *   └─────────────────────────────────────────────────┘
 *                          ▲
 *          ┌───────────────┴───────────────┐
 *          │                               │
 *   ┌──────┴──────┐                 ┌──────┴──────┐
 *   │   FM部分    │                 │  Deep部分   │
 *   │ 一阶+二阶   │                 │  多层DNN    │
 *   └─────────────┘                 └─────────────┘
 *          ▲                               ▲
 *          └───────────────┬───────────────┘
 *                          │
 *                   ┌──────┴──────┐
 *                   │  输入层      │
 *                   │ 稀疏特征     │
 *                   └─────────────┘
 */
@Slf4j
@Component
public class DeepFMModel {

    private MultiLayerNetwork model;
    private final AtomicLong predictionCounter = new AtomicLong(0);

    // 模型超参数
    private static final int EMBEDDING_SIZE = 16;
    private static final int DEEP_LAYER_1 = 256;
    private static final int DEEP_LAYER_2 = 128;
    private static final int DEEP_LAYER_3 = 64;
    private static final double LEARNING_RATE = 0.001;
    private static final int BATCH_SIZE = 512;
    private static final int EPOCHS = 10;

    @PostConstruct
    public void init() {
        loadOrCreateModel();
    }

    /**
     * 构建DeepFM网络结构
     * 
     * 输入特征包括:
     * - 稀疏特征:用户ID、商品ID、类目ID、品牌ID等(需要Embedding)
     * - 稠密特征:用户年龄、商品价格、统计特征等
     */
    private MultiLayerNetwork buildModel(int numSparseFeatures, int numDenseFeatures) {
        // 稀疏特征的总Embedding维度 = 稀疏特征数 × Embedding大小
        int sparseEmbeddingDim = numSparseFeatures * EMBEDDING_SIZE;
        int inputDim = sparseEmbeddingDim + numDenseFeatures;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(42)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(new Adam(LEARNING_RATE))
            .weightInit(WeightInit.XAVIER)
            .list()
            // FM部分:一阶特征层(直接输出)
            .layer(0, new DenseLayer.Builder()
                .nIn(inputDim)
                .nOut(1)
                .activation(Activation.IDENTITY)
                .weightInit(WeightInit.XAVIER)
                .build())
            // FM部分:二阶特征交叉(通过Embedding内积实现)
            // 这里简化为一个单独的层
            .layer(1, new DenseLayer.Builder()
                .nIn(sparseEmbeddingDim)
                .nOut(1)
                .activation(Activation.IDENTITY)
                .weightInit(WeightInit.XAVIER)
                .build())
            // Deep部分:多层神经网络
            .layer(2, new DenseLayer.Builder()
                .nIn(inputDim)
                .nOut(DEEP_LAYER_1)
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .dropOut(0.2)
                .build())
            .layer(3, new DenseLayer.Builder()
                .nIn(DEEP_LAYER_1)
                .nOut(DEEP_LAYER_2)
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .dropOut(0.2)
                .build())
            .layer(4, new DenseLayer.Builder()
                .nIn(DEEP_LAYER_2)
                .nOut(DEEP_LAYER_3)
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .dropOut(0.2)
                .build())
            // 输出层:合并FM和Deep的输出
            .layer(5, new OutputLayer.Builder()
                .nIn(DEEP_LAYER_3 + 2)  // Deep输出 + FM一阶 + FM二阶
                .nOut(1)
                .activation(Activation.SIGMOID)
                .lossFunction(LossFunctions.LossFunction.XENT)
                .weightInit(WeightInit.XAVIER)
                .build())
            .backpropType(BackpropType.Standard)
            .build();

        return new MultiLayerNetwork(conf);
    }

    /**
     * 训练模型
     * 
     * 训练数据来自用户历史行为日志,
     * 正样本是用户点击/购买的样本,
     * 负样本是曝光但未点击的样本。
     */
    public void train(List<TrainingSample> samples) {
        log.info("Starting model training with {} samples", samples.size());

        // 1. 特征预处理
        List<float[]> features = new ArrayList<>();
        List<Float> labels = new ArrayList<>();

        for (TrainingSample sample : samples) {
            float[] featureVector = extractFeatureVector(sample);
            features.add(featureVector);
            labels.add(sample.getLabel() ? 1.0f : 0.0f);
        }

        // 2. 转换为NDArray格式
        // 实际训练代码省略...

        log.info("Model training completed");
    }

    /**
     * 在线预测
     * 
     * @param featureVector 特征向量
     * @return 点击率预测值(0-1之间)
     */
    public float predict(float[] featureVector) {
        long startTime = System.nanoTime();

        // 模型推理
        // INDArray input = Nd4j.create(featureVector);
        // INDArray output = model.output(input);
        // float ctr = output.getFloat(0);

        // 简化实现:返回模拟分数
        float ctr = 0.05f;

        long duration = (System.nanoTime() - startTime) / 1000;  // 微秒
        predictionCounter.incrementAndGet();

        // 慢查询监控
        if (duration > 1000) {
            log.warn("Slow prediction: {} μs", duration);
        }

        return ctr;
    }

    /**
     * 批量预测
     * 
     * 批量预测比单条预测更高效,可以利用向量化计算。
     */
    public float[] predictBatch(List<float[]> featureVectors) {
        float[] results = new float[featureVectors.size()];
        for (int i = 0; i < featureVectors.size(); i++) {
            results[i] = predict(featureVectors[i]);
        }
        return results;
    }

    private float[] extractFeatureVector(TrainingSample sample) {
        // 将原始特征转换为数值向量
        // 包括:
        // - 稀疏特征Embedding
        // - 稠密特征归一化
        float[] vector = new float[128];
        // 实现省略
        return vector;
    }

    private void loadOrCreateModel() {
        File modelFile = new File("models/deepfm.zip");
        if (modelFile.exists()) {
            // 加载已有模型
            log.info("Loading existing model from {}", modelFile.getAbsolutePath());
        } else {
            // 创建新模型
            log.info("Creating new model");
            // model = buildModel(numSparseFeatures, numDenseFeatures);
            // model.init();
        }
    }

    public void saveModel(String path) {
        log.info("Saving model to {}", path);
        // model.save(new File(path));
    }

    public long getPredictionCount() {
        return predictionCounter.get();
    }
}

来源:
https://tmywi.cn/category/meishi.html

相关文章
|
8天前
|
人工智能 数据可视化 安全
王炸组合!阿里云 OpenClaw X 飞书 CLI,开启 Agent 基建狂潮!(附带免费使用6个月服务器)
本文详解如何用阿里云Lighthouse一键部署OpenClaw,结合飞书CLI等工具,让AI真正“动手”——自动群发、生成科研日报、整理知识库。核心理念:未来软件应为AI而生,CLI即AI的“手脚”,实现高效、安全、可控的智能自动化。
34498 21
王炸组合!阿里云 OpenClaw X 飞书 CLI,开启 Agent 基建狂潮!(附带免费使用6个月服务器)
|
19天前
|
人工智能 JSON 机器人
让龙虾成为你的“公众号分身” | 阿里云服务器玩Openclaw
本文带你零成本玩转OpenClaw:学生认证白嫖6个月阿里云服务器,手把手配置飞书机器人、接入免费/高性价比AI模型(NVIDIA/通义),并打造微信公众号“全自动分身”——实时抓热榜、AI选题拆解、一键发布草稿,5分钟完成热点→文章全流程!
45353 142
让龙虾成为你的“公众号分身” | 阿里云服务器玩Openclaw
|
2天前
|
人工智能 自然语言处理 安全
Claude Code 全攻略:命令大全 + 实战工作流(建议收藏)
本文介绍了Claude Code终端AI助手的使用指南,主要内容包括:1)常用命令如版本查看、项目启动和更新;2)三种工作模式切换及界面说明;3)核心功能指令速查表,包含初始化、压缩对话、清除历史等操作;4)详细解析了/init、/help、/clear、/compact、/memory等关键命令的使用场景和语法。文章通过丰富的界面截图和场景示例,帮助开发者快速掌握如何通过命令行和交互界面高效使用Claude Code进行项目开发,特别强调了CLAUDE.md文件作为项目知识库的核心作用。
2877 8
Claude Code 全攻略:命令大全 + 实战工作流(建议收藏)
|
9天前
|
人工智能 JSON 监控
Claude Code 源码泄露:一份价值亿元的 AI 工程公开课
我以为顶级 AI 产品的护城河是模型。读完这 51.2 万行泄露的源码,我发现自己错了。
4989 21
|
2天前
|
人工智能 监控 安全
阿里云SASE 2.0升级,全方位监控Agent办公安全
AI Agent办公场景的“安全底座”
1136 1
|
8天前
|
人工智能 API 开发者
阿里云百炼 Coding Plan 售罄、Lite 停售、Pro 抢不到?最新解决方案
阿里云百炼Coding Plan Lite已停售,Pro版每日9:30限量抢购难度大。本文解析原因,并提供两大方案:①掌握技巧抢购Pro版;②直接使用百炼平台按量付费——新用户赠100万Tokens,支持Qwen3.5-Max等满血模型,灵活低成本。
1948 6
阿里云百炼 Coding Plan 售罄、Lite 停售、Pro 抢不到?最新解决方案