ChatGLM2 源码解析:`MLP`

简介: ChatGLM2 源码解析:`MLP`

class MLP(torch.nn.Module):
    """MLP.
    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """
    def __init__(self, config: ChatGLMConfig, device=None):
        super(MLP, self).__init__()
        self.add_bias = config.add_bias_linear
        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        # LL1,最后一维 ES => 4ES
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )
        def swiglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]
        self.activation_func = swiglu
        # LL2,最后一维 4ES => ES
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )
    def forward(self, hidden_states):
        # 输入 -> LL1 -> swiglu -> LL2 -> 输出
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        output = self.dense_4h_to_h(intermediate_parallel)
        return output
相关文章
|
19天前
|
XML Java Android开发
Android实现自定义进度条(源码+解析)
Android实现自定义进度条(源码+解析)
50 1
|
23天前
|
存储 NoSQL 算法
【Redis技术进阶之路】「底层源码解析」揭秘高效存储模型与数据结构底层实现(字典)(二)
【Redis技术进阶之路】「底层源码解析」揭秘高效存储模型与数据结构底层实现(字典)
35 0

推荐镜像

更多