pytorch损失函数binary_cross_entropy和binary_cross_entropy_with_logits的区别

简介: binary_cross_entropy和binary_cross_entropy_with_logits都是来自torch.nn.functional的函数

binary_cross_entropybinary_cross_entropy_with_logits都是来自torch.nn.functional的函数,首先对比官方文档对它们的区别:


函数名 解释
binary_cross_entropy Function that measures the Binary Cross Entropy between the target and the output
binary_cross_entropy_with_logits Function that measures Binary Cross Entropy between target and output logits


区别只在于这个logits,那么这个logits是什么意思呢?以下是从网络上找到的一个答案:


有一个(类)损失函数名字中带了with_logits. 而这里的logits指的是,该损失函数已经内部自带了计算logit的操作,无需在传入给这个loss函数之前手动使用sigmoid/softmax将之前网络的输入映射到[0,1]之间


再看看官方给的示例代码:

binary_cross_entropy:

input = torch.randn((3, 2), requires_grad=True)
target = torch.rand((3, 2), requires_grad=False)
loss = F.binary_cross_entropy(F.sigmoid(input), target)
loss.backward()
# input is  tensor([[-0.5474,  0.2197],
#         [-0.1033, -1.3856],
#         [-0.2582, -0.1918]], requires_grad=True)
# target is  tensor([[0.7867, 0.5643],
#         [0.2240, 0.8263],
#         [0.3244, 0.2778]])
# loss is  tensor(0.8196, grad_fn=<BinaryCrossEntropyBackward>)

binary_cross_entropy_with_logits:

input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
loss = F.binary_cross_entropy_with_logits(input, target)
loss.backward()
# input is  tensor([ 1.3210, -0.0636,  0.8165], requires_grad=True)
# target is  tensor([0., 1., 1.])
# loss is  tensor(0.8830, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

的确binary_cross_entropy_with_logits不需要sigmoid函数了。

事实上,官方是推荐使用函数带有with_logits的,解释是


This loss combines a Sigmoid layer and the BCELoss in one single class. This version is more numerically stable than using a plain Sigmoid followed by a BCELoss as, by combining the operations into one layer, we take advantage of the log-sum-exp trick for numerical stability.


翻译一下就是说将sigmoid层和binaray_cross_entropy合在一起计算比分开依次计算有更好的数值稳定性,这主要是运用了log-sum-exp技巧。

那么这个log-sum-exp主要就是讲如何防止数值计算溢出的问题:


l o g s u m e x p ( x 1 , x 2 , . . . , x n ) = l o g ( ∑ i = 1 n e x i ) logsumexp(x_1,x_2,...,x_n) = log(\sum_{i=1}^{n}e^{x_i}) logsumexp(x1,x2,...,xn)=log(i=1∑nexi)针对上述式子,如果 x i x_i xi很大,那么 e x i e^{x_i} exi很有可能会溢出,为了避免这样的问题,上式可以进行如下变换:


l o g ( ∑ i = 1 n e x i ) = l o g ( e c ∑ i = 1 n e x i − c ) = c l o g e + l o g ( ∑ i = 1 n e x i − c ) log(\sum_{i=1}^{n}e^{x_i})=log(e^c\sum_{i=1}^{n}e^{x_i-c})=cloge+log(\sum_{i=1}^{n}e^{x_i-c}) log(i=1∑nexi)=log(eci=1∑nexi−c)=cloge+log(i=1∑nexi−c)于是乎,这样就可以避免数据溢出了。


目录
相关文章
|
2月前
|
机器学习/深度学习 文字识别 PyTorch
PyTorch内置损失函数汇总 !!
PyTorch内置损失函数汇总 !!
44 0
|
4月前
|
机器学习/深度学习 PyTorch TensorFlow
|
4月前
|
数据挖掘 PyTorch 算法框架/工具
人脸识别中的损失函数ArcFace及其实现过程代码(pytorch)--理解softmax损失函数及Arcface
人脸识别中的损失函数ArcFace及其实现过程代码(pytorch)--理解softmax损失函数及Arcface
176 0
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
深度学习框架:Pytorch与Keras的区别与使用方法
深度学习框架:Pytorch与Keras的区别与使用方法
29 0
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch torch.nn库以及nn与nn.functional有什么区别?
Pytorch torch.nn库以及nn与nn.functional有什么区别?
46 0
|
9月前
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch学习笔记(6):模型的权值初始化与损失函数
Pytorch学习笔记(6):模型的权值初始化与损失函数
116 0
Pytorch学习笔记(6):模型的权值初始化与损失函数
|
11月前
|
PyTorch 算法框架/工具
【PyTorch】rand/randn/randint/randperm的区别
【PyTorch】rand/randn/randint/randperm的区别
63 0
|
11月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【PyTorch】nn.ReLU()与F.relu()的区别
【PyTorch】nn.ReLU()与F.relu()的区别
101 0