pytorch中meter.ClassErrorMeter()使用方法

简介: pytorch中meter.ClassErrorMeter()使用方法

PyTorchNet从TorchNet迁移而来,其中提供了很多有用的工具,例如meter,meter提供了一些轻量级的工具,帮助用户快速计算训练过程中的一些指标。

AverageValueMeter能够计算所有数的平均值和标准差,同意几个epoch中损失的平均值。

ClassErrorMeter能够计算每个epoch下的类别错误率,在分类任务经常使用。

下面我们介绍下如何使用ClassErrorMeter()这个方法计算每个epoch的图像分类准确率,对于这个目的,我们可以通过定义变量,然后不断累加每个批次的数据,然后进行计算,但是现在有一个更好的工具,可以帮助我们实现这个操作。

首先使用meter.ClassErrorMeter()实例化一个类,该类可以想成内部有一个集合,里面会保存一些数据,并定义一些方法能够对这些数据进行处理来满足我们的要求,说白了就是把我们正常计算指标的代码封装到一个类中。

error_meter = meter.ClassErrorMeter()

我们只需要调用类的add函数不断将数据添加到其中即可,该函数有两个参数,分别是outputtarget,第一个参数是模型的softmax输出结果,第二个参数是对应的标签。

error_meter.add(output.detach(), labels)

然后等一个epoch的所有结果全部填入其中之后,就可以使用error_meter.value获得结果

error_meter.value()

但是注意一个问题,他计算的是错误率,如果想要正确率,那么需要用100减去它即可。

而且还需要注意一个问题,当我们处理完一个epoch之后,需要清空当前的信息,只需要调用reset()即可。

下面使用一个示例来说明如何使用:

for epoch in range(20):
    model.train()
    for data in tqdm(train_loader):
        images, labels = data
        optimizer.zero_grad()
        output = model(images)
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()
        loss_meter.add(loss.item())
        error_meter.add(output.detach(), labels)
     # 打印信息
    print("【EPOCH: 】%s" % str(epoch + 1))
    print("训练集损失为%s" % (str(loss_meter.mean)))
    print("训练集精度为%s" % (str(100 - error_meter.value()[0]) + '%'))
    loss_meter.reset()
    error_meter.reset()
    model.eval()
    for data in tqdm(val_loader):
        images, labels = data
        optimizer.zero_grad()
        output = model(images)
        loss = loss_function(output, labels)
        loss.backward()
        optimizer.step()
        loss_meter.add(loss.item())
        error_meter.add(output.detach(), labels)
    print("【EPOCH: 】%s" % str(epoch + 1))
    print("验证集损失为%s" % (str(loss_meter.mean)))
    print("验证集精度为%s" % (str(100 - error_meter.value()[0]) + '%'))
    loss_meter.reset()
    error_meter.reset()


目录
相关文章
|
PyTorch 算法框架/工具
pytorch中torch.clamp()使用方法
pytorch中torch.clamp()使用方法
608 0
pytorch中torch.clamp()使用方法
|
并行计算 PyTorch 测试技术
PyTorch 之 简介、相关软件框架、基本使用方法、tensor 的几种形状和 autograd 机制-2
由于要进行 tensor 的学习,因此,我们先导入我们需要的库。
|
机器学习/深度学习 人工智能 自然语言处理
PyTorch 之 简介、相关软件框架、基本使用方法、tensor 的几种形状和 autograd 机制-1
PyTorch 是一个基于 Torch 的 Python 开源机器学习库,用于自然语言处理等应用程序。它主要由 Facebook 的人工智能小组开发,不仅能够实现强大的 GPU 加速,同时还支持动态神经网络,这一点是现在很多主流框架如 TensorFlow 都不支持的。
|
机器学习/深度学习 人工智能 PyTorch
|
机器学习/深度学习 PyTorch 算法框架/工具
pytorch中nn.Parameter()使用方法
pytorch中nn.Parameter()使用方法
1461 1
|
PyTorch 算法框架/工具
pytorch中ImageFolder()使用方法
pytorch中ImageFolder()使用方法
358 0
pytorch中ImageFolder()使用方法
|
PyTorch 算法框架/工具 异构计算
基于Pytorch查看本地或者远程服务器GPU及使用方法
基于Pytorch查看本地或者远程服务器GPU及使用方法
526 0
基于Pytorch查看本地或者远程服务器GPU及使用方法
|
PyTorch 算法框架/工具
pytorch中keepdim参数归并操作使用方法
pytorch中keepdim参数归并操作使用方法
160 0
|
PyTorch 算法框架/工具
pytorch中meter.AverageValueMeter()使用方法
pytorch中meter.AverageValueMeter()使用方法
289 0
|
PyTorch 算法框架/工具
pytorch中nn.ModuleList()使用方法
pytorch中nn.ModuleList()使用方法
357 0

热门文章

最新文章