ChatGLM2 源码解析:`ChatGLMForConditionalGeneration.forward`

简介: ChatGLM2 源码解析:`ChatGLMForConditionalGeneration.forward`

class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    """ChatGLM2 backbone plus its output layer, with an optional LM loss.

    ``forward`` runs ``self.transformer``, projects the resulting hidden
    states through ``self.transformer.output_layer`` and, when ``labels``
    are supplied, computes a next-token cross-entropy loss.
    """

    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        """Build the wrapped ``ChatGLMModel`` and optionally quantize it.

        Args:
            config: Model configuration. ``max_length`` and
                ``quantization_bit`` are read here.
            empty_init: Forwarded unchanged to ``ChatGLMModel``.
            device: Forwarded unchanged to ``ChatGLMModel``.
        """
        super().__init__(config)

        self.max_sequence_length = config.max_length
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
        self.config = config
        # Starts out False; presumably updated by ``quantize`` (defined
        # outside this excerpt) -- TODO confirm.
        self.quantized = False

        # A truthy ``quantization_bit`` triggers quantization at build time.
        if self.config.quantization_bit:
            self.quantize(self.config.quantization_bit, empty_init=True)

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_last_logit: Optional[bool] = False,
    ):
        """Run the backbone, compute logits and (optionally) the LM loss.

        Args:
            input_ids, position_ids, attention_mask, past_key_values,
            inputs_embeds, output_hidden_states: Passed straight through to
                ``self.transformer``.
            labels: Target token ids. When given, a shifted cross-entropy
                loss is computed; positions equal to ``-100`` are ignored.
            use_cache: Falls back to ``config.use_cache`` when ``None``.
            output_attentions: Accepted for interface compatibility; this
                method does not forward it to the backbone.
            return_dict: Falls back to ``config.use_return_dict`` when
                ``None``. Selects between a ``CausalLMOutputWithPast`` and
                a plain tuple.
            return_last_logit: When truthy, only the final slice along the
                first axis of the hidden states is projected to logits.

        Returns:
            ``CausalLMOutputWithPast`` if ``return_dict`` is truthy,
            otherwise the tuple ``(logits, *backbone_outputs[1:])`` with
            ``loss`` prepended when it was computed.
        """
        # Resolve unset flags from the model configuration.
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_out = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # First backbone output holds the hidden states; per the layout
        # note below they are sequence-first, so ``[-1:]`` keeps only the
        # last token's hidden state.
        last_hidden = backbone_out[0]
        if return_last_logit:
            last_hidden = last_hidden[-1:]

        # Project the hidden states through the output layer to get
        # per-token vocabulary logits.
        lm_logits = self.transformer.output_layer(last_hidden)
        # [SL, BS, ...] => [BS, SL, ...]
        lm_logits = lm_logits.transpose(0, 1).contiguous()

        loss = None
        if labels is not None:
            # The loss is evaluated on a float32 copy of the logits.
            logits_fp32 = lm_logits.to(torch.float32)

            # Tokens before position i predict token i. For the text
            # [A, B, C, D, E]: logits come from [A, B, C, D] and the
            # targets are [B, C, D, E].
            pred_logits = logits_fp32[..., :-1, :].contiguous()
            target_ids = labels[..., 1:].contiguous()

            # Flatten logits to [BS * (SL - 1), VS] and targets to
            # [BS * (SL - 1)] before taking the cross entropy.
            vocab_dim = pred_logits.size(-1)
            flat_logits = pred_logits.view(-1, vocab_dim)
            flat_targets = target_ids.view(-1)
            criterion = CrossEntropyLoss(ignore_index=-100)
            loss = criterion(flat_logits, flat_targets)

            # Cast logits and loss back to the hidden-state dtype.
            lm_logits = logits_fp32.to(last_hidden.dtype)
            loss = loss.to(last_hidden.dtype)

        # Structured output: loss, logits, KV cache, hidden states and
        # attentions.
        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=lm_logits,
                past_key_values=backbone_out.past_key_values,
                hidden_states=backbone_out.hidden_states,
                attentions=backbone_out.attentions,
            )

        # Tuple output: logits followed by the remaining backbone outputs,
        # with the loss prepended only when it was computed.
        tuple_out = (lm_logits,) + backbone_out[1:]
        if loss is None:
            return tuple_out
        return (loss,) + tuple_out


相关文章
|
6天前
yolo-world 源码解析(六)(2)
yolo-world 源码解析(六)
16 0
|
6天前
yolo-world 源码解析(六)(1)
yolo-world 源码解析(六)
8 0
|
6天前
yolo-world 源码解析(五)(4)
yolo-world 源码解析(五)
16 0
|
6天前
yolo-world 源码解析(五)(1)
yolo-world 源码解析(五)
29 0
|
6天前
yolo-world 源码解析(二)(2)
yolo-world 源码解析(二)
20 0
|
6天前
Marker 源码解析(二)(3)
Marker 源码解析(二)
10 0
|
21天前
|
XML Java Android开发
Android实现自定义进度条(源码+解析)
Android实现自定义进度条(源码+解析)
50 1
|
24天前
|
存储 NoSQL 算法
【Redis技术进阶之路】「底层源码解析」揭秘高效存储模型与数据结构底层实现(字典)(二)
【Redis技术进阶之路】「底层源码解析」揭秘高效存储模型与数据结构底层实现(字典)
36 0
|
6天前
Marker 源码解析(一)(4)
Marker 源码解析(一)
11 0

推荐镜像

更多