Caffe Softmax 层的实现原理【细节补充】

简介: Caffe Softmax 层的实现原理【细节补充】

本文是看了知乎的这篇文章以后觉得作者写的很好,但是有些细节讲解得不够详细,回复里面大家也多有疑问,特加以补充:


为了对原作者表示尊重和感谢,先注明原作出处:


作者:John Wang


链接:https://www.zhihu.com/question/28927103/answer/78810153



作者原文和我的补充


====================================


设 z 是 softmax loss 层的输入,f(z)是 softmax 的输出,即


image.png

y 是输入样本 z 对应的类别,y=0,1,...,N

对于 z ,其损失函数定义为

image.png

展开上式:

image.png

对上式求导,有

image.png

梯度下降方向即为

image.png

====================================

增加关于 softmax 层的反向传播说明

设 softmax 的输出为 a ,输入为 z ,损失函数为 l

image.png

image.png

其中

image.png在 caffe 中是 top_diff,a 为 caffe 中得 top_data,需要计算的是image.png

image.png  if i!=k

image.pngif i==k

【我的补充】

----------------------------------------------------------------

image.png

image.png

----------------------------------------------------------------

于是

image.png

【我的补充】

----------------------------------------------------------------

image.png

image.png

image.png

image.png

把负号提出去,改为点乘,即得到上式。注意,这里的 n 表示 channels,这里的 k 和 caffe 源码中的 k 含义不同。

----------------------------------------------------------------

整理一下得到

image.png

其中image.png表示将标量扩展为 n 维向量,表示向量按元素相乘


【我的补充】


----------------------------------------------------------------


这边作者讲解得有误,因为对照代码可以发现,点乘后其实得到的是 1*inner_num  大小的向量,所以为了对应通道相减,需要将其扩展为 channels*inner_num 的矩阵,而不是 n 维向量。


最后矩阵再按元素进行相乘。



 对照 caffe 源码

  // top_diff : l 对 a 向量求偏导
  // top_data :a 向量
  // 将 top_diff 拷贝到 bottom_diff
  // dim = channels * inner_num_
  // inner_num_ = height * width
  caffe_copy(top[0]->count(), top_diff, bottom_diff);
  // 遍历一个 batch 中的样本
  for (int i = 0; i < outer_num_; ++i) {
    // compute dot(top_diff, top_data) and subtract them from the bottom diff
    // 此处计算两个向量的点积,注意 top_diff 已经拷贝到 bottom_diff 当中
    // 步长为 inner_num_(跨通道)构造一个长度为 channels (类别个数)的向量,进行点乘
    for (int k = 0; k < inner_num_; ++k) {
      scale_data[k] = caffe_cpu_strided_dot<Dtype>(channels,
          bottom_diff + i * dim + k, inner_num_,
          top_data + i * dim + k, inner_num_);
    }
    // subtraction
    // 此处计算大括号内的减法(即负号)
    // 将 scale_data 扩展为 channels 个通道(多少个类别),再和 bottom_diff 对应的通道相减
    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, inner_num_, 1,
        -1., sum_multiplier_.cpu_data(), scale_data, 1., bottom_diff + i * dim);
  }
  // elementwise multiplication
  // 元素级的乘法
  // 此处计算大括号外和 a 向量的乘法
  caffe_mul(top[0]->count(), bottom_diff, top_data, bottom_diff)
目录
相关文章
|
存储 并行计算 Linux
国产之路:复旦微FMQL调试笔记1:PS网口
FPGA,全程现场可编程门阵列,是指一切通过软件手段更改、配置器件内部连接结构和逻辑单元,完成既定设计功能的数字集成电路。换个简单通俗的介绍方式,就好比一个全能的运动员,FPGA就是这么神奇的可以通过设定而实现各种复杂的功能电路。
2514 0
国产之路:复旦微FMQL调试笔记1:PS网口
SQLSTATE[42S02]: Base table or view not found: 1146 Table ‘thinkphp.test‘ don‘t exsit
SQLSTATE[42S02]: Base table or view not found: 1146 Table ‘thinkphp.test‘ don‘t exsit
702 0
|
Java 数据库连接 数据库
Mybatis系列(四)之Mybatis与Spring整合以及Aop整合pagehelper插件
Mybatis系列(四)之Mybatis与Spring整合以及Aop整合pagehelper插件
|
5天前
|
存储 人工智能 安全
AI 越智能,数据越危险?
阿里云提供AI全栈安全能力,为客户构建全链路数据保护体系,让企业敢用、能用、放心用
|
8天前
|
域名解析 人工智能
【实操攻略】手把手教学,免费领取.CN域名
即日起至2025年12月31日,购买万小智AI建站或云·企业官网,每单可免费领1个.CN域名首年!跟我了解领取攻略吧~
|
7天前
|
数据采集 人工智能 自然语言处理
3分钟采集134篇AI文章!深度解析如何通过云无影AgentBay实现25倍并发 + LlamaIndex智能推荐
结合阿里云无影 AgentBay 云端并发采集与 LlamaIndex 智能分析,3分钟高效抓取134篇 AI Agent 文章,实现 AI 推荐、智能问答与知识沉淀,打造从数据获取到价值提炼的完整闭环。
448 93
|
1天前
|
开发者
「玩透ESA」ESA启用和加速-ER在加速场景中的应用
本文介绍三种配置方法:通过“A鉴权”模板创建函数并设置触发器路由;在ESA上配置回源302跟随;以及自定义响应头。每步均配有详细截图指引,帮助开发者快速完成相关功能设置,提升服务安全性与灵活性。
286 2