想要神经网络输出分类的概率值?应该这样写代码

简介: 想要神经网络输出分类的概率值?应该这样写代码

我们构造一个简单的神经网络,通常情况下n_output是分类数量,例如二分类任务那n_output=2、六分类任务那么n_output=6


class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.inLayer = torch.nn.Linear(n_feature, n_hidden) # 输入层
        self.hiddenLayer = torch.nn.Linear(n_hidden, n_hidden) # 隐藏层
        self.outLayer = torch.nn.Linear(n_hidden, n_output) # 输出层
    # 前向计算函数,定义完成后会隐式地自动生成反向传播函数
    def forward(self, x):
        x = F.relu(self.hiddenLayer(self.inLayer(x)))
        x = self.outLayer(x)
        return x


假设要进行一个二分类任务,这时的输出x的值是类别0和类别1的预测值,我们可以使用softmax函数将其转换为概率值:


class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.inLayer = torch.nn.Linear(n_feature, n_hidden) # 输入层
        self.hiddenLayer = torch.nn.Linear(n_hidden, n_hidden) # 隐藏层
        self.outLayer = torch.nn.Linear(n_hidden, n_output) # 输出层
    # 前向计算函数,定义完成后会隐式地自动生成反向传播函数
    def forward(self, x):
        x = F.relu(self.hiddenLayer(self.inLayer(x)))
        x = self.outLayer(x)
        prob = F.softmax(x, dim=1) # softmax将预测值转换为概率
        out = torch.unsqueeze(prob.argmax(dim=1), dim=1) # argmax选取概率最大的预测项,unsqueeze扩维
        return out


输出结果例如:


# prob = F.softmax(x, dim=1)的输出结果
# print(prob)
tensor([[0.5059, 0.4941],
        [0.5018, 0.4982],
        [0.4886, 0.5114]], grad_fn=<SoftmaxBackward0>)
# out = torch.unsqueeze(prob.argmax(dim=1), dim=1)的输出结果
# print(out)
tensor([[0],
        [0],
        [1]])
相关文章
|
29天前
|
机器学习/深度学习 算法 PyTorch
RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)
RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)
217 1
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
ICLR 2024 Spotlight:训练一个图神经网络即可解决图领域所有分类问题!
【2月更文挑战第17天】ICLR 2024 Spotlight:训练一个图神经网络即可解决图领域所有分类问题!
59 2
ICLR 2024 Spotlight:训练一个图神经网络即可解决图领域所有分类问题!
|
2月前
|
机器学习/深度学习 编解码 文件存储
YOLOv8改进 | 2023Neck篇 | BiFPN双向特征金字塔网络(附yaml文件+代码)
YOLOv8改进 | 2023Neck篇 | BiFPN双向特征金字塔网络(附yaml文件+代码)
255 0
|
2月前
|
机器学习/深度学习 算法 PyTorch
python手把手搭建图像多分类神经网络-代码教程(手动搭建残差网络、mobileNET)
python手把手搭建图像多分类神经网络-代码教程(手动搭建残差网络、mobileNET)
46 0
|
2月前
|
机器学习/深度学习 测试技术 Ruby
YOLOv8改进 | 主干篇 | 反向残差块网络EMO一种轻量级的CNN架构(附完整代码 + 修改教程)
YOLOv8改进 | 主干篇 | 反向残差块网络EMO一种轻量级的CNN架构(附完整代码 + 修改教程)
89 0
|
2月前
|
机器学习/深度学习 测试技术 Ruby
YOLOv5改进 | 主干篇 | 反向残差块网络EMO一种轻量级的CNN架构(附完整代码 + 修改教程)
YOLOv5改进 | 主干篇 | 反向残差块网络EMO一种轻量级的CNN架构(附完整代码 + 修改教程)
128 2
|
6天前
|
机器学习/深度学习 数据采集 TensorFlow
R语言KERAS深度学习CNN卷积神经网络分类识别手写数字图像数据(MNIST)
R语言KERAS深度学习CNN卷积神经网络分类识别手写数字图像数据(MNIST)
23 0
|
7天前
|
机器学习/深度学习 数据可视化 网络架构
matlab使用长短期记忆(LSTM)神经网络对序列数据进行分类
matlab使用长短期记忆(LSTM)神经网络对序列数据进行分类
12 0
|
15天前
|
传感器 监控 安全
|
29天前
|
机器学习/深度学习 PyTorch 算法框架/工具
卷积神经元网络中常用卷积核理解及基于Pytorch的实例应用(附完整代码)
卷积神经元网络中常用卷积核理解及基于Pytorch的实例应用(附完整代码)
20 0

热门文章

最新文章