Mixtral MOE 部分源码解析

本文涉及的产品
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
云解析 DNS,旗舰版 1个月
简介: Mixtral MOE 部分源码解析
# 单个专家的架构,就是经典的 FFN
class MixtralBLockSparseTop2MLP(nn.Module):
    def __init__(self, config: MixtralConfig):
        super().__init__()
        # FFNSize,一般是 HidSize x4
        self.ffn_dim = config.intermediate_size
        # HidSize,隐藏状态的向量尺寸
        self.hidden_dim = config.hidden_size
        # 用于隐藏状态扩张的线性层
        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        # 用于隐藏状态收缩的线性层
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        # 用于计算隐藏状态门控的线性层
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
    def forward(self, hidden_states):
        # 输入隐藏状态的形状为 [BatchSize, SeqLen, HidSize]、
        # 输入经过第三个线性层并激活,得到门控
        # 输入经过第一个线性层,乘以门控,经过第二个线性层,得到输出
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states
# MOE 的架构
class MixtralSparseMoeBlock(nn.Module):
    """
    This implementation is
    strictly equivalent to standard MoE with full capacity (no
    dropped tokens). It's faster since it formulates MoE operations
    in terms of block-sparse operations to accomodate imbalanced
    assignments of tokens to experts, whereas standard MoE either
    (1) drop tokens at the cost of reduced performance or (2) set
    capacity factor to number of experts and thus waste computation
    and memory on padding.
    """
    def __init__(self, config):
        super().__init__()
        # HidSize,隐藏状态的向量尺寸
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        # NExp,专家数量
        self.num_experts = config.num_local_experts
        # TopK,激活的专家数量
        self.top_k = config.num_experts_per_tok
        # 门控线性层
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        # 专家模块列表,每个都是 FFN
        self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)])
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """ """
        # 输入尺寸:[BatchSize, SeqLen, HidSize]
        # 获取 BatchSize(批量大小)
        #     SeqLen(序列长度)
        #     HidSize(隐藏状态尺寸)
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # 将输入前两维合并,[BatchSize * SeqLen, HidSize]
        hidden_states = hidden_states.view(-1, hidden_dim)
        # 将隐藏状态传入门控线性层得到专家得分
        # 每个样本的每个单词都有一组得分
        # [BatchSize * SeqLen, NExp]
        router_logits = self.gate(hidden_states)
        # 专家得分经过 Softmax 得到专家概率
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        # 计算每个得分的 TOPK,得到专家索引
        # routing_weights:TOPK 专家概率,[BatchSize * SeqLen, TopK]
        # selected_experts:TOPK 专家索引,[BatchSize * SeqLen, TopK]
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        # 专家概率归一化,使每组得分和为一
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # 转换为输入的数据类型
        routing_weights = routing_weights.to(hidden_states.dtype)
        # 将最终的隐藏状态初始化为零,用于累加
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )
        # 将专家索引单热化,交换前后两维,得到专家的掩码
        # [NExp, TopK, BatchSize * SeqLen]
        # mask[i, j, k] 表示第 k 个单词的第 j 个专家是不是专家 i
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
        # 遍历每个专家,expert_idx 为专家索引
        for expert_idx in range(self.num_experts):
            # 获取当前专家模块
            expert_layer = self.experts[expert_idx]
            # 使用索引来索引掩码,得到当前专家的掩码矩阵
            # [TopK, BatchSize * SeqLen]
            # 它的元素 [i, j] 表示第 j 个样本的第 i 个专家是不是当前专家
            # where 计算调用该专家的单词序号(top_x),以及该专家的排名(idx)
            idx, top_x = torch.where(expert_mask[expert_idx])
            # 如果没有单词调用该专家,转到下一个
            if top_x.shape[0] == 0:
                continue
            # 转 Python 列表
            top_x_list = top_x.tolist()
            idx_list = idx.tolist()
            # 获取调用该专家的单词的隐藏状态,[NHid, HidSize]
            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            # 将隐藏状态传入当前专家,得到专家输出,[NHid, HidSize]
            # 获取调用该专家的单词的专家概率,[NHid, 1]
            # 二者相乘
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
            # 将隐藏状态加到最终隐藏状态
            # 即 final_hidden_states[top_x[i]] += current_hidden_states[i]
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        # 拆分第一维,[BatchSize, SeqLen, HidSize]
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits


相关文章
|
1月前
|
监控 Java 应用服务中间件
高级java面试---spring.factories文件的解析源码API机制
【11月更文挑战第20天】Spring Boot是一个用于快速构建基于Spring框架的应用程序的开源框架。它通过自动配置、起步依赖和内嵌服务器等特性,极大地简化了Spring应用的开发和部署过程。本文将深入探讨Spring Boot的背景历史、业务场景、功能点以及底层原理,并通过Java代码手写模拟Spring Boot的启动过程,特别是spring.factories文件的解析源码API机制。
71 2
|
2月前
|
缓存 Java 程序员
Map - LinkedHashSet&Map源码解析
Map - LinkedHashSet&Map源码解析
76 0
|
17天前
|
PyTorch Shell API
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。
|
22天前
|
缓存 监控 Java
Java线程池提交任务流程底层源码与源码解析
【11月更文挑战第30天】嘿,各位技术爱好者们,今天咱们来聊聊Java线程池提交任务的底层源码与源码解析。作为一个资深的Java开发者,我相信你一定对线程池并不陌生。线程池作为并发编程中的一大利器,其重要性不言而喻。今天,我将以对话的方式,带你一步步深入线程池的奥秘,从概述到功能点,再到背景和业务点,最后到底层原理和示例,让你对线程池有一个全新的认识。
50 12
|
1月前
|
存储 安全 Linux
Golang的GMP调度模型与源码解析
【11月更文挑战第11天】GMP 调度模型是 Go 语言运行时系统的核心部分,用于高效管理和调度大量协程(goroutine)。它通过少量的操作系统线程(M)和逻辑处理器(P)来调度大量的轻量级协程(G),从而实现高性能的并发处理。GMP 模型通过本地队列和全局队列来减少锁竞争,提高调度效率。在 Go 源码中,`runtime.h` 文件定义了关键数据结构,`schedule()` 和 `findrunnable()` 函数实现了核心调度逻辑。通过深入研究 GMP 模型,可以更好地理解 Go 语言的并发机制。
|
1月前
|
消息中间件 缓存 安全
Future与FutureTask源码解析,接口阻塞问题及解决方案
【11月更文挑战第5天】在Java开发中,多线程编程是提高系统并发性能和资源利用率的重要手段。然而,多线程编程也带来了诸如线程安全、死锁、接口阻塞等一系列复杂问题。本文将深度剖析多线程优化技巧、Future与FutureTask的源码、接口阻塞问题及解决方案,并通过具体业务场景和Java代码示例进行实战演示。
54 3
|
2月前
|
存储
让星星⭐月亮告诉你,HashMap的put方法源码解析及其中两种会触发扩容的场景(足够详尽,有问题欢迎指正~)
`HashMap`的`put`方法通过调用`putVal`实现,主要涉及两个场景下的扩容操作:1. 初始化时,链表数组的初始容量设为16,阈值设为12;2. 当存储的元素个数超过阈值时,链表数组的容量和阈值均翻倍。`putVal`方法处理键值对的插入,包括链表和红黑树的转换,确保高效的数据存取。
63 5
|
2月前
|
Java Spring
Spring底层架构源码解析(三)
Spring底层架构源码解析(三)
147 5
|
2月前
|
XML Java 数据格式
Spring底层架构源码解析(二)
Spring底层架构源码解析(二)
|
2月前
|
算法 Java 程序员
Map - TreeSet & TreeMap 源码解析
Map - TreeSet & TreeMap 源码解析
42 0

推荐镜像

更多
下一篇
DataWorks