1、Ignite的简介
ignite是一个高级库,可帮助你在PyTorch中训练神经网络。
1、ignite可帮助您用几行代码编写紧凑而功能齐全的培训循环
2、您将获得一个包含指标,提前停止,模型检查点和其他功能的训练循环,而无需样板
下面我们展示了使用纯pytorch和使用ignite创建训练循环来训练和验证您的模型的偶然比较,并偶尔进行了检查:
如图可以看出,带有ignite的代码更加简洁和易读。此外,添加额外的度量标准或提早停止之类的事情虽然轻而易举,但是当“迭代的”训练循环时,可能会开始迅速增加代码的复杂性。
ignite主要亮点功能:
对于训练过程中的for循环,精简代码,提供度量,提前终止,保存模型,提供基于visdom和tensorBoardX的训练可视化。
2、Ignite各模块使用介绍
2.1、Engine
ignite框架最基本的概念,循环一定的次数,循环的过程为基于训练数据,更新模型的参数。也可以加上评估的过程,基于验证数据集,计算损失函数的值。示例代码:
def update_model(engine, batch): inputs, targets = batch optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, targets) loss.backward() optimizer.step() return loss.item() trainer = Engine(update_model) @trainer.on(Events.ITERATION_COMPLETED(every=100)) def log_training(engine): batch_loss = engine.state.output lr = optimizer.param_groups[0]['lr'] e = engine.state.epoch n = engine.state.max_epochs i = engine.state.iteration print("Epoch {}/{} : {} - batch loss: {}, lr: {}".format(e, n, i, batch_loss, lr)) trainer.run(data_loader, max_epochs=5) > Epoch 1/5 : 100 - batch loss: 0.10874069479016124, lr: 0.01 > ... > Epoch 2/5 : 1700 - batch loss: 0.4217900575859437, lr: 0.01
2.2、Events and Handlers
为了在训练过程中和外界进行交互,引用事件的机制。
事件触发的时间点包括:
engine 的开始,结束
epoch 的开始,结束
batch iteration的开始,结束
用户注册事件的处理函数Handler,处理函数Handler在框架的事件触发时,会被回调。注册有两种方式:
add_event_handler()
on注解器
示例代码:
engine = Engine(process_function) def print_epoch(engine): print("Epoch: {}".format(engine.state.epoch)) engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch)
2.3、时间轴
如下是框架的时间轴,我们主要理解以下:
可以注册epoch结束时的处理函数,在此函数中可以进行在验证数据集上的验证过程,判断是否进行提早终止训练,或者更新学习率(想必每个知道深度神经网络的同学应该都知道动态学习率的概念)。
2.4、state
Engine类中包含了 State的对象。用于在事件处理程序之间传递内部状态和用户定义的状态。默认情况下,状态包含以下属性:
state.iteration # 训练结束后的迭代次数
state.epoch # 当前的轮数
state.seed # 每个epoch的seed
state.dataloader # engine的dataloader
state.epoch_length # epoch的可选长度
state.max_epochs # 训练的最大轮数
state.batch # `process_function`的batch
state.output # 训练结束后,在Engine中定义的处理函数的输出
state.metrics # 度量方式
def update(engine, batch): x, y = batch y_pred = model(inputs) loss = loss_fn(y_pred, y) optimizer.zero_grad() loss.backward() optimizer.step() return loss.item() def on_iteration_completed(engine): iteration = engine.state.iteration epoch = engine.state.epoch loss = engine.state.output print("Epoch: {}, Iteration: {}, Loss: {}".format(epoch, iteration, loss)) trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)
这个例子中,engine.state.output保存了损失函数的值。
engine.state.output保存的值是在Engine中定义的处理函数的输出,这个输出的类型是没有明确的,所以我们可以灵活使用。
在看一个例子:
def update(engine, batch): x, y = batch y_pred = model(inputs) loss = loss_fn(y_pred, y) optimizer.zero_grad() loss.backward() optimizer.step() return loss.item(), y_pred, y trainer = Engine(update) @trainer.on(Events.EPOCH_COMPLETED) def print_loss(engine): epoch = engine.state.epoch loss = engine.state.output[0] print ('Epoch {epoch}: train_loss = {loss}'.format(epoch=epoch, loss=loss)) accuracy = Accuracy(output_transform=lambda x: [x[1], x[2]]) accuracy.attach(trainer, 'acc') trainer.run(data, max_epochs=10)
这个例子中,在Engine中定义的处理函数的输出,也就是update函数的返回值为一个tuple:loss,y_pred, y
对比,看这个例子:
def update(engine, batch): x, y = batch y_pred = model(inputs) loss = loss_fn(y_pred, y) optimizer.zero_grad() loss.backward() optimizer.step() return {'loss': loss.item(), 'y_pred': y_pred, 'y': y} trainer = Engine(update) @trainer.on(Events.EPOCH_COMPLETED) def print_loss(engine): epoch = engine.state.epoch loss = engine.state.output['loss'] print ('Epoch {epoch}: train_loss = {loss}'.format(epoch=epoch, loss=loss)) accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']]) accuracy.attach(trainer, 'acc') trainer.run(data, max_epochs=10)
在Engine中定义的处理函数的输出,也就是update函数的返回值为一个字典:loss,y_pred,y。所以其他地方访问engine.state.output中的数据时,需要按照字典的方式。