【Ignite实践】更少的代码量训练模型(一)

简介: 【Ignite实践】更少的代码量训练模型(一)

1、Ignite的简介


ignite是一个高级库,可帮助你在PyTorch中训练神经网络。

   1、ignite可帮助您用几行代码编写紧凑而功能齐全的培训循环

   2、您将获得一个包含指标,提前停止,模型检查点和其他功能的训练循环,而无需样板

下面我们展示了使用纯pytorch和使用ignite创建训练循环来训练和验证您的模型的偶然比较,并偶尔进行了检查:

如图可以看出,带有ignite的代码更加简洁和易读。此外,添加额外的度量标准或提早停止之类的事情虽然轻而易举,但是当“迭代的”训练循环时,可能会开始迅速增加代码的复杂性。


ignite主要亮点功能:

     对于训练过程中的for循环,精简代码,提供度量,提前终止,保存模型,提供基于visdom和tensorBoardX的训练可视化。


2、Ignite各模块使用介绍


2.1、Engine

   ignite框架最基本的概念,循环一定的次数,循环的过程为基于训练数据,更新模型的参数。也可以加上评估的过程,基于验证数据集,计算损失函数的值。示例代码:

def update_model(engine, batch):
    inputs, targets = batch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item()
trainer = Engine(update_model)
@trainer.on(Events.ITERATION_COMPLETED(every=100))
def log_training(engine):
    batch_loss = engine.state.output
    lr = optimizer.param_groups[0]['lr']
    e = engine.state.epoch
    n = engine.state.max_epochs
    i = engine.state.iteration
    print("Epoch {}/{} : {} - batch loss: {}, lr: {}".format(e, n, i, batch_loss, lr))
trainer.run(data_loader, max_epochs=5)
> Epoch 1/5 : 100 - batch loss: 0.10874069479016124, lr: 0.01
> ...
> Epoch 2/5 : 1700 - batch loss: 0.4217900575859437, lr: 0.01


2.2、Events and Handlers

为了在训练过程中和外界进行交互,引用事件的机制。

事件触发的时间点包括:

      engine 的开始,结束

      epoch 的开始,结束

      batch iteration的开始,结束

用户注册事件的处理函数Handler,处理函数Handler在框架的事件触发时,会被回调。注册有两种方式:

       add_event_handler()

       on注解器

示例代码:

engine = Engine(process_function)
def print_epoch(engine):
    print("Epoch: {}".format(engine.state.epoch))
engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch)


2.3、时间轴

   如下是框架的时间轴,我们主要理解以下:

   可以注册epoch结束时的处理函数,在此函数中可以进行在验证数据集上的验证过程,判断是否进行提早终止训练,或者更新学习率(想必每个知道深度神经网络的同学应该都知道动态学习率的概念)。

image.png


2.4、state

   Engine类中包含了 State的对象。用于在事件处理程序之间传递内部状态和用户定义的状态。默认情况下,状态包含以下属性:

state.iteration         # 训练结束后的迭代次数

state.epoch             # 当前的轮数

state.seed              # 每个epoch的seed

state.dataloader        # engine的dataloader

state.epoch_length      # epoch的可选长度

state.max_epochs        # 训练的最大轮数

state.batch             # `process_function`的batch

state.output            # 训练结束后,在Engine中定义的处理函数的输出

state.metrics           # 度量方式

def update(engine, batch):
    x, y = batch
    y_pred = model(inputs)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
def on_iteration_completed(engine):
    iteration = engine.state.iteration
    epoch = engine.state.epoch
    loss = engine.state.output
    print("Epoch: {}, Iteration: {}, Loss: {}".format(epoch, iteration, loss))
trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)

这个例子中,engine.state.output保存了损失函数的值。

engine.state.output保存的值是在Engine中定义的处理函数的输出,这个输出的类型是没有明确的,所以我们可以灵活使用。


在看一个例子:

def update(engine, batch):
    x, y = batch
    y_pred = model(inputs)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), y_pred, y
trainer = Engine(update)
@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):
    epoch = engine.state.epoch
    loss = engine.state.output[0]
    print ('Epoch {epoch}: train_loss = {loss}'.format(epoch=epoch, loss=loss))
accuracy = Accuracy(output_transform=lambda x: [x[1], x[2]])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)

这个例子中,在Engine中定义的处理函数的输出,也就是update函数的返回值为一个tuple:loss,y_pred, y


对比,看这个例子:

def update(engine, batch):
    x, y = batch
    y_pred = model(inputs)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return {'loss': loss.item(),
            'y_pred': y_pred,
            'y': y}
trainer = Engine(update)
@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):
    epoch = engine.state.epoch
    loss = engine.state.output['loss']
    print ('Epoch {epoch}: train_loss = {loss}'.format(epoch=epoch, loss=loss))
accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)

在Engine中定义的处理函数的输出,也就是update函数的返回值为一个字典:loss,y_pred,y。所以其他地方访问engine.state.output中的数据时,需要按照字典的方式。

相关文章
|
Linux
【PyAutoGUI操作指南】05 屏幕截图与图像定位:截图+定位单个目标+定位全部目标+灰度匹配+像素匹配+获取屏幕截图中像素的RGB颜色
【PyAutoGUI操作指南】05 屏幕截图与图像定位:截图+定位单个目标+定位全部目标+灰度匹配+像素匹配+获取屏幕截图中像素的RGB颜色
2018 0
|
搜索推荐 Linux 定位技术
|
10月前
|
数据挖掘 Linux 索引
服务器数据恢复—服务器意外断电导致数据丢失的数据恢复案例
一台安装linux系统的服务器意外断电。管理员重启服务器后进行检测,发现服务器上部分文件丢失。管理员没有进行任何操作,直接将服务器正常关机并切断电源。
|
Shell Docker 容器
5-17|gitlab的runner什么意思
5-17|gitlab的runner什么意思
石英晶体是如何产生振荡的?以及cpu倍频的由来
本文是关于石英晶体振荡器的学习笔记,适合计算机科学与技术背景的读者。内容涵盖了石英晶体振荡器的基本原理,包括压电效应、等效电路、谐振频率,以及不同类型振荡器的特性和参数。此外,还讨论了石英晶体振荡器的小型化、高精度、低噪声、低功耗发展趋势,并列举了它们在石英钟、彩电和通信系统中的应用。最后提到了处理器倍频的概念,解释了其原理和实际应用中的限制。
石英晶体是如何产生振荡的?以及cpu倍频的由来
|
关系型数据库 MySQL 数据库
mysql的用户管理和权限控制
本文介绍了MySQL中用户管理的基本操作,包括创建用户、修改密码、删除用户、查询权限、授予权限和撤销权限的方法。
509 2
|
存储 NoSQL 数据库
Harbor 共享后端高可用-简单版
主机配置包括3台服务器,运行Harbor v2.10.0和Docker 24.0.5,其中10.0.90.68额外运行Postgres+Redis。基础安装配置中详细描述了Docker的安装步骤,包括添加仓库、安装、配置国内镜像源和启动Docker。安装postgres+redis服务使用docker-compose.yml文件,通过`docker-compose up -d`命令启动。最后,安装Harbor涉及修改harbor.yml配置文件,设置主机名、数据库和Redis连接信息,然后运行`install.sh`脚本。
429 3
|
机器学习/深度学习 存储 分布式计算
解释 Spark 在 Databricks 中的使用方式
【8月更文挑战第12天】
731 1
|
存储 网络架构
Vue3-admin-element框架学习笔记----5(最终篇--动态路由)
Vue3-admin-element框架学习笔记----5(最终篇--动态路由)
335 0
|
C语言 计算机视觉 Python
【Qt】Qt下配置OpenCV
【Qt】Qt下配置OpenCV
449 3

热门文章

最新文章