开发者社区 > 大数据与机器学习 > 人工智能平台PAI > 正文

机器学习PAI 负采样版本DSSM双塔模型训练好之后,怎么分别获取?

机器学习PAI 负采样版本DSSM双塔模型训练好之后,怎么分别获取user tower的embedding和item tower的embedding?

展开
收起
真的很搞笑 2023-11-19 14:58:23 99 0
3 条回答
写回答
取消 提交回答
  • 在训练好负采样版本的DSSM双塔模型后,可以通过以下步骤分别获取user tower和item tower的embedding:

    1. 首先,加载训练好的模型参数。假设模型保存在model_path路径下,可以使用以下代码加载模型参数:
    import torch
    from dssm import DSSMModel
    
    model = DSSMModel(...)  # 初始化模型
    model.load_state_dict(torch.load(model_path))
    model.eval()  # 设置为评估模式
    
    1. 然后,定义一个函数来获取指定user或item的embedding。假设有一个用户ID为user_id,可以定义如下函数来获取该用户的embedding:
    def get_user_embedding(user_id, model):
        with torch.no_grad():
            # 将user ID转换为tensor并扩展维度以适应模型输入要求
            user_id_tensor = torch.tensor([user_id]).unsqueeze(0).to(device)
            # 通过模型获取user tower的输出表示
            user_output = model.user_tower(user_id_tensor)
            # 返回user tower的embedding(即输出表示的第一个元素)
            return user_output[0].detach().cpu().numpy()
    

    类似地,可以定义另一个函数来获取指定item的embedding:

    def get_item_embedding(item_id, model):
        with torch.no_grad():
            # 将item ID转换为tensor并扩展维度以适应模型输入要求
            item_id_tensor = torch.tensor([item_id]).unsqueeze(0).to(device)
            # 通过模型获取item tower的输出表示
            item_output = model.item_tower(item_id_tensor)
            # 返回item tower的embedding(即输出表示的第一个元素)
            return item_output[0].detach().cpu().numpy()
    
    1. 最后,使用上述定义的函数分别获取user tower和item tower的embedding。例如,要获取用户ID为123的用户embedding和物品ID为456的物品embedding,可以执行以下操作:
    user_embedding = get_user_embedding(123, model)
    item_embedding = get_item_embedding(456, model)
    

    这样,就可以分别获得user tower和item tower的embedding了。

    2023-11-29 16:08:29
    赞同 展开评论 打赏
  • 面对过去,不要迷离;面对未来,不必彷徨;活在今天,你只要把自己完全展示给别人看。

    在 MaxCompute PAI 中,DSSM 双塔模型训练完成后,可以使用以下步骤获得 user 和 item 的 embedding:

    1. 首先,在训练模型时指定 --dump-model 参数,以保存模型和参数。
    2. 训练结束后,可以通过以下命令将模型参数导出为 JSON 格式:

      pai dump-model dssm_model --output output_path
      

      其中 dssm_model 是模型名称,output_path 是保存参数的路径。

    3. 使用以下命令将 JSON 格式的参数转换成 HDF5 格式,便于处理:

      ./model_converter.py model_path output_path
      

      其中 model_path 是 JSON 格式的参数文件路径,output_path 是 HDF5 格式的文件路径。

    4. 最后,使用以下命令将 HDF5 文件解压:

      h5py -d model.h5 model
      

      其中 model.h5 是 HDF5 文件路径,model 是模型名。

    2023-11-19 15:13:26
    赞同 展开评论 打赏

人工智能平台 PAI(Platform for AI,原机器学习平台PAI)是面向开发者和企业的机器学习/深度学习工程平台,提供包含数据标注、模型构建、模型训练、模型部署、推理优化在内的AI开发全链路服务,内置140+种优化算法,具备丰富的行业场景插件,为用户提供低门槛、高性能的云原生AI工程化能力。

相关产品

  • 人工智能平台 PAI
  • 热门讨论

    热门文章

    相关电子书

    更多
    微博机器学习平台架构和实践 立即下载
    机器学习及人机交互实战 立即下载
    大数据与机器学习支撑的个性化大屏 立即下载