ChatGLM2 Source Code Walkthrough: `GLMTransformer`

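import torch
import torch.utils.checkpoint
from typing import Optional
from torch.nn import LayerNorm
from transformers.utils import logging

from .configuration_chatglm import ChatGLMConfig

logger = logging.get_logger(__name__)

# GLMBlock and RMSNorm are defined earlier in modeling_chatglm.py.
# Encoder module: the stack of all GLM blocks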
class GLMTransformer(torch.nn.Module):
    """Transformer class."""
    def __init__(self, config: ChatGLMConfig, device=None):
        super(GLMTransformer, self).__init__()
        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = config.post_layer_norm
        # Number of layers (LC, the layer count)
        self.num_layers = config.num_layers
        # Transformer blocks (TFBlocks); layer numbers start at 1
        def build_layer(layer_number):
            return GLMBlock(config, layer_number, device=device)
        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
        # If a final LayerNorm follows the last block, create it
        if self.post_layer_norm:
            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                 dtype=config.torch_dtype)
        self.gradient_checkpointing = False
    def _get_layer(self, layer_number):
        return self.layers[layer_number]
    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
            use_cache: Optional[bool] = True,
            output_hidden_states: Optional[bool] = False,
    ):
        # If no KV cache is passed, use [None] * LC so every layer can be called the same way
        if not kv_caches:
            kv_caches = [None for _ in range(self.num_layers)]
        # `presents` collects the KV cache of every layer
        presents = () if use_cache else None
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False
        all_self_attentions = None
        # `all_hidden_states` stores the input and the output of every layer
        all_hidden_states = () if output_hidden_states else None
        # input -> TFBlock1 -> TFBlock2 -> ... -> TFBlockN -> LN (optional) -> output
        for index in range(self.num_layers):
            # Store the current layer's input in `all_hidden_states`
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            # Fetch the current layer and feed it the input to get its output and KV cache
            layer = self._get_layer(index)
            if self.gradient_checkpointing and self.training:
                layer_ret = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_caches[index],
                    use_cache
                )
            else:
                layer_ret = layer(
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_cache=kv_caches[index],
                    use_cache=use_cache
                )
            # The output becomes the next layer's input
            hidden_states, kv_cache = layer_ret
            # Save this layer's KV cache
            if use_cache:
                presents = presents + (kv_cache,)
        # Store the last layer's output (before the final LN) in `all_hidden_states`
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        # Pass the last layer's output through the final LN to get the GLM output
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)
        # Return the GLM output, all layers' KV caches, all hidden states, and the attention matrices (not collected, so None)
        return hidden_states, presents, all_hidden_states, all_self_attentions
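
The bookkeeping in `forward` can be checked without a model. The following is a minimal sketch in plain Python, not ChatGLM2 code: hidden states are lists instead of tensors, the hypothetical `toy_layer` stands in for `GLMBlock` (it adds 1 to every element and returns a fake KV cache), and the hypothetical `toy_final_norm` stands in for `final_layernorm`. It shows that `all_hidden_states` ends up with `num_layers + 1` entries and that its last entry is recorded before the final norm.

```python
from typing import Callable, List, Optional, Tuple

Hidden = List[float]
ToyCache = Tuple[str, int]  # hypothetical stand-in for a (key, value) tensor pair


def toy_layer(layer_number: int) -> Callable:
    """Hypothetical stand-in for GLMBlock: returns (new_hidden, kv_cache)."""
    def layer(hidden: Hidden, kv_cache: Optional[ToyCache], use_cache: bool):
        new_hidden = [h + 1.0 for h in hidden]
        if not use_cache:
            return new_hidden, None
        # The toy cache records its layer and how many positions it has seen.
        seen = (kv_cache[1] if kv_cache else 0) + len(hidden)
        return new_hidden, (f"layer{layer_number}", seen)
    return layer


def toy_final_norm(hidden: Hidden) -> Hidden:
    """Hypothetical stand-in for final_layernorm: divide by the largest |value|."""
    scale = max(abs(h) for h in hidden)
    return [h / scale for h in hidden]


def toy_transformer_forward(layers, hidden, kv_caches=None, use_cache=True,
                            output_hidden_states=False, post_layer_norm=True):
    """Same control flow as GLMTransformer.forward, minus masks and rotary embeddings."""
    num_layers = len(layers)
    if not kv_caches:
        kv_caches = [None] * num_layers
    presents = () if use_cache else None
    all_hidden_states = () if output_hidden_states else None
    for index in range(num_layers):
        if output_hidden_states:
            all_hidden_states += (hidden,)  # input of layer `index`
        hidden, kv_cache = layers[index](hidden, kv_caches[index], use_cache)
        if use_cache:
            presents += (kv_cache,)
    if output_hidden_states:
        all_hidden_states += (hidden,)  # last layer's output, before the final norm
    if post_layer_norm:
        hidden = toy_final_norm(hidden)
    return hidden, presents, all_hidden_states, None


layers = [toy_layer(i + 1) for i in range(3)]
out, presents, states, attentions = toy_transformer_forward(
    layers, [0.0, 1.0], output_hidden_states=True)
print(len(states))  # 4: the input plus one output per layer
print(states[-1])   # [3.0, 4.0]: recorded before the final norm
print(out)          # [0.75, 1.0]
print(presents)     # (('layer1', 2), ('layer2', 2), ('layer3', 2))
```

With `use_cache=False`, `presents` stays `None`, matching `presents = () if use_cache else None` in the original. Note that in the original, the gradient-checkpointing branch switches `use_cache` off only after `presents` has been initialized, so in that case `forward` returns an empty tuple rather than `None`.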

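`LayerNormFunc` selects `RMSNorm` when `config.rmsnorm` is set and `LayerNorm` otherwise. LayerNorm centers the vector (subtracts the mean) before dividing by the standard deviation, whereas RMSNorm skips centering and divides by the root mean square. The sketch below implements both textbook formulas on plain Python lists; it is not ChatGLM2's `RMSNorm` class, `eps` plays the role of `config.layernorm_epsilon`, and LayerNorm's learned weight and bias are omitted for brevity.

```python
import math
from typing import List, Optional


def layer_norm(x: List[float], eps: float = 1e-5) -> List[float]:
    """Textbook LayerNorm (affine weight and bias omitted): center, then divide by the std."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def rms_norm(x: List[float], weight: Optional[List[float]] = None,
             eps: float = 1e-5) -> List[float]:
    """Textbook RMSNorm: no centering, divide by the root mean square, then scale by weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    if weight is None:
        weight = [1.0] * len(x)  # the learned weight, left at ones here
    return [w * v / rms for w, v in zip(weight, x)]


x = [3.0, 4.0]
print([round(v, 4) for v in layer_norm(x)])  # [-1.0, 1.0]
print([round(v, 4) for v in rms_norm(x)])    # [0.8485, 1.1314]
```

Both outputs have unit scale, but only the LayerNorm output is zero-mean; RMSNorm saves the mean computation.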

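`kv_caches` and `presents` form a round trip during generation: the caller passes the `presents` returned by one call back in as `kv_caches` for the next, and on the first call `kv_caches` is empty, so it is replaced by `[None] * num_layers`. The toy below illustrates only this bookkeeping. It assumes, for illustration, that each layer's cache is a plain list of that layer's inputs; `toy_block` and `toy_forward_with_cache` are hypothetical names, not ChatGLM2 functions.

```python
from typing import List, Optional

ToyCache = List[float]  # hypothetical: the cached "keys" (here, the layer's inputs)


def toy_block(hidden: List[float], kv_cache: Optional[ToyCache], use_cache: bool):
    """Hypothetical stand-in for GLMBlock that only models cache growth."""
    past = kv_cache if kv_cache is not None else []
    # Build a new list so the caller's old cache is not mutated; analogous to
    # concatenating new keys onto cached keys along the sequence dimension.
    keys = past + list(hidden)
    new_hidden = [h + 1.0 for h in hidden]
    return new_hidden, (keys if use_cache else None)


def toy_forward_with_cache(hidden, num_layers, kv_caches=None, use_cache=True):
    if not kv_caches:
        kv_caches = [None] * num_layers  # first call: nothing cached yet
    presents = () if use_cache else None
    for index in range(num_layers):
        hidden, kv_cache = toy_block(hidden, kv_caches[index], use_cache)
        if use_cache:
            presents += (kv_cache,)
    return hidden, presents


# Prefill: the whole 3-token prompt goes through once.
_, cache = toy_forward_with_cache([1.0, 2.0, 3.0], num_layers=2)
print([len(c) for c in cache])  # [3, 3]
# Decode: only the newest token is fed in; `presents` comes back as `kv_caches`.
_, cache = toy_forward_with_cache([4.0], num_layers=2, kv_caches=cache)
print([len(c) for c in cache])  # [4, 4]
print(cache[0])                 # [1.0, 2.0, 3.0, 4.0]
```

In the real model each cache entry is a `(key, value)` pair of tensors, and `SelfAttention` concatenates the new keys and values onto the cached ones along the sequence dimension; the toy uses lists of the layer's input values instead.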