ChatGLM2 Source Code Analysis: `GLMBlock`

Summary: a walkthrough of `GLMBlock`, the single transformer layer of ChatGLM2, covering its layer norms, self-attention, MLP, and residual connections.

```python
# A GLM block consists of the attention layer, the FFN layer,
# and the residual connections between them
class GLMBlock(torch.nn.Module):
    """A single transformer layer.
    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """
    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(GLMBlock, self).__init__()
        self.layer_number = layer_number
        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
        self.fp32_residual_connection = config.fp32_residual_connection
        # Choose RMSNorm or LayerNorm according to config.rmsnorm
        # (see the RMSNorm sketch after this class)
        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
        # LN1
        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                             dtype=config.torch_dtype)
        # Attention layer
        self.self_attention = SelfAttention(config, layer_number, device=device)
        # Dropout
        self.hidden_dropout = config.hidden_dropout
        # LN2
        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                      dtype=config.torch_dtype)
        # FFN
        self.mlp = MLP(config, device=device)

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
    ):
        # hidden_states: [s, b, h]
        # input -> LN1 -> attention -> ...
        layernorm_output = self.input_layernorm(hidden_states)
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            attention_mask,
            rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache
        )
        # Decide where the first residual branch starts.
        # If apply_residual_connection_post_layernorm is True, the residual
        # is taken from the output of LN1:
        # input -> LN1 -> attention -> Dropout -> ⊕ -> ...
        #            |                            ↑
        #            +----------------------------+
        # Otherwise it is taken from the block input:
        # input -> LN1 -> attention -> Dropout -> ⊕ -> ...
        #   |                                     ↑
        #   +-------------------------------------+
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states
        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
        layernorm_input = residual + layernorm_input
        # ... -> LN2 -> FFN -> ...
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        mlp_output = self.mlp(layernorm_output)
        # Decide where the second residual branch starts.
        # If apply_residual_connection_post_layernorm is True, the residual
        # is taken from the output of LN2:
        # ... -> LN2 -> FFN -> Dropout -> ⊕ -> output
        #          |                      ↑
        #          +----------------------+
        # Otherwise it is taken from the input of LN2:
        # ... -> LN2 -> FFN -> Dropout -> ⊕ -> output
        #  |                              ↑
        #  +------------------------------+
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input
        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
        output = residual + output
        return output, kv_cache
```
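
To make the wiring easier to see end to end, here is a minimal, self-contained sketch of the same pre-LN layout in plain PyTorch. It is not the real `GLMBlock`: `RMSNorm`, `SelfAttention`, and `MLP` are swapped for `nn.LayerNorm`, `nn.MultiheadAttention`, and a small feed-forward stack; the rotary embeddings, attention mask, and KV cache are dropped; and the sizes (`hidden_size=128`, `num_heads=4`) are arbitrary. Only the LN -> sublayer -> dropout -> residual pattern and the `apply_residual_connection_post_layernorm` branching are kept.

```python
import torch
from torch import nn


class ToyPreLNBlock(nn.Module):
    """Simplified stand-in for GLMBlock: keeps only the
    LN -> sublayer -> dropout -> residual wiring."""

    def __init__(self, hidden_size=128, num_heads=4, dropout=0.0,
                 apply_residual_connection_post_layernorm=False):
        super().__init__()
        self.post_ln_residual = apply_residual_connection_post_layernorm
        self.ln1 = nn.LayerNorm(hidden_size)
        # stand-in for SelfAttention; expects [s, b, h] inputs by default
        self.attn = nn.MultiheadAttention(hidden_size, num_heads)
        self.ln2 = nn.LayerNorm(hidden_size)
        # stand-in for the MLP (FFN)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):  # x: [s, b, h]
        # first sub-block: LN1 -> attention -> dropout -> residual
        ln1_out = self.ln1(x)
        attn_out, _ = self.attn(ln1_out, ln1_out, ln1_out)
        residual = ln1_out if self.post_ln_residual else x
        x = residual + self.dropout(attn_out)
        # second sub-block: LN2 -> FFN -> dropout -> residual
        ln2_out = self.ln2(x)
        residual = ln2_out if self.post_ln_residual else x
        x = residual + self.dropout(self.ffn(ln2_out))
        return x


x = torch.randn(16, 2, 128)        # [seq_len, batch, hidden]
print(ToyPreLNBlock()(x).shape)    # torch.Size([16, 2, 128])
```

With the flag left at `False`, each residual is taken from the sub-block input, matching the `else` branches of `GLMBlock.forward` (`residual = hidden_states` and `residual = layernorm_input`); setting it to `True` moves both branch points to the layernorm outputs.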
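
The `config.rmsnorm` switch above decides whether `input_layernorm` and `post_attention_layernorm` are built as `RMSNorm` or `LayerNorm`. As a rough illustration of the difference, the sketch below is a hypothetical minimal RMSNorm, not the exact class shipped with ChatGLM2: it rescales each hidden vector by its root mean square with a learned gain and, unlike LayerNorm, subtracts no mean and adds no bias; the `eps` value and dtype handling here are assumptions.

```python
import torch


class SimpleRMSNorm(torch.nn.Module):
    """Hypothetical minimal RMSNorm: normalize by the root mean square of the
    last dimension with a learned gain; no mean subtraction, no bias."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        # compute in float32 for numerical stability, cast back afterwards
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        normed = x.float() * torch.rsqrt(variance + self.eps)
        return (self.weight * normed).to(x.dtype)


x = torch.randn(16, 2, 128)
print(SimpleRMSNorm(128)(x).shape)    # torch.Size([16, 2, 128])
```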