1. 可复现性配置
# Reproducibility configuration: seed every RNG the training pipeline touches
# (PyTorch CPU, Python stdlib, NumPy, all CUDA devices) with the same value.
import torch
import random
import numpy as np

myseed = 12345
torch.manual_seed(myseed)
random.seed(myseed)  # bug fix: was random.seed(0), inconsistent with myseed
np.random.seed(myseed)
# Force deterministic cuDNN kernels (disables the autotuner; may be slower).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.cuda.manual_seed_all(myseed)  # no-op when CUDA is unavailable
2. 模型搭建
import torch


class Net(torch.nn.Module):
    """Minimal model template.

    Register layers (and any hyperparameters reused later) in ``__init__``;
    wire them together in ``forward`` and return the output.
    """

    def __init__(self):
        super().__init__()
        # Put network layers and other global hyperparameters used later
        # here, e.g.: self.fc = torch.nn.Linear(in_dim, out_dim)

    def forward(self, x):
        # Compose the modules created in __init__ here; stateless ops such
        # as activation functions can also be applied inline.
        # bug fix: the original body was comments only (no return, and a
        # SyntaxError once reformatted) — identity placeholder keeps the
        # template runnable until real layers are added.
        return x


model = Net()
2.1 激活函数
常用激活函数:Sigmoid,ReLU,tanh
3. 数据集
Dataset
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    """Map-style dataset template.

    ``__init__`` is where you can split the data according to a
    train/val/test mode argument if needed.
    """

    def __init__(self, file):
        # Load samples from `file` here.  bug fix: the original placeholder
        # `self.data = ...` bound Ellipsis, so __len__/__getitem__ raised
        # TypeError; treating `file` itself as the sample sequence keeps the
        # template runnable until real loading logic is filled in.
        self.data = file

    def __getitem__(self, index):
        # Return one sample.
        return self.data[index]

    def __len__(self):
        # Number of samples.
        return len(self.data)
DataLoader
# Wrap the dataset in a DataLoader for mini-batch iteration.
# shuffle: set True for training, False for evaluation/testing.
dataset = MyDataset(file)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
4. 训练→验证→测试,模型保存
4.1 训练
训练模型
# Move the model to the GPU (single-GPU case only; for multi-GPU see
# https://blog.csdn.net/PolarisRisingWar/article/details/116069338 which
# covers torch.nn.DataParallel; more distributed-training methods in later
# posts).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define the optimizer and the loss function.
# NOTE: some models need a custom loss (e.g. the PTA model) — PTA defines a
# loss_function method on the model that returns a one-element loss Tensor.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
criterion = torch.nn.MSELoss()

# Optional: pre-processing (e.g. PTA runs label propagation before training).

model.train()
for epoch in range(epochs):
    # Some models do not train in mini-batches (e.g. many GNNs consume the
    # whole graph at once).
    for batch_x, batch_y in train_data:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        prediction = model(batch_x)
        loss = criterion(prediction, batch_y)
        loss.backward()
        # bug fix: the optimizer was created as `optim` but stepped as
        # `optimizer` — a NameError at runtime; the name is now consistent.
        optimizer.step()
        optimizer.zero_grad()

    # Optional: post-processing (e.g. the C&S and PTA models).
    # Validation: keep the checkpoint that performs best on the validation
    # set; early stopping: if validation performance has not improved within
    # a patience threshold, stop training.

# Ordering notes: loss -> backward() -> step() is a strict sequence.
# zero_grad() can live on the model or the optimizer (usually the optimizer),
# and may come either before backward() or after step() (usually after
# step(), i.e. once the update for this batch is done).
储存模型
# Persist only the learnable parameters (the state_dict), not the module
# object itself.
PATH = '.model.pth'
state = model.state_dict()
torch.save(state, PATH)
4.2 加载模型,验证和测试
加载储存在本地的模型
# Rebuild the architecture first, then load the saved parameters into it.
model = Net()
state = torch.load(PATH)
model.load_state_dict(state)
验证和测试
model.eval()  # switch to inference mode (disables dropout, uses running BN stats)
with torch.no_grad():  # no autograd bookkeeping during evaluation
    for batch_x in test_data:
        # bug fix: Tensor.to() is out-of-place — the original discarded the
        # result (`batch_x.to(device)`), so batch_x stayed on the CPU.
        batch_x = batch_x.to(device)
        prediction = model(batch_x)
        # On a validation set you can also compute the loss:
        # loss = criterion(prediction, batch_y)
5. 辅助可视化
5.1 绘制沿epoch的loss变化曲线图(在训练或验证时储存记录)
def plot_learning_curve(loss_record, title=''):
    ''' Plot learning curve of your DNN (train & dev loss).

    loss_record: dict with 'train' and 'dev' lists of recorded losses;
    the dev record is sparser, so it is spread evenly over the train axis.
    '''
    total_steps = len(loss_record['train'])
    x_1 = range(total_steps)
    # bug fix: guard the slice step with max(..., 1) — the original integer
    # division yields step 0 (ValueError) when dev has more entries than train.
    step = max(len(loss_record['train']) // len(loss_record['dev']), 1)
    x_2 = x_1[::step]
    # bug fix: was a bare `figure(...)`, undefined unless separately imported;
    # use plt.figure for consistency with the rest of the function.
    plt.figure(figsize=(6, 4))
    plt.plot(x_1, loss_record['train'], c='tab:red', label='train')
    plt.plot(x_2, loss_record['dev'], c='tab:cyan', label='dev')
    plt.ylim(0.0, 5.)
    plt.xlabel('Training steps')
    plt.ylabel('MSE loss')
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
示例图:
5.2 绘制沿epoch的loss和ACC变化曲线图(在训练或验证时储存记录)
# Loss curves (train/val/test) over epochs, saved under pics_root.
# bug fix: the first title contains Chinese text but lacked the
# `fontproperties=font` that the ACC title below uses, so the CJK glyphs
# would not render; made consistent.
plt.title(dataset_name+'数据集在'+model_name+'模型上的loss', fontproperties=font)
plt.plot(train_losses, label="training loss")
plt.plot(val_losses, label="validating loss")
plt.plot(test_losses, label="testing loss")
plt.legend()
plt.savefig(pics_root+'/loss_'+pics_name)
plt.close()  # close the figure so the next plot does not draw on top of it

# Accuracy curves (train/val/test) over epochs.
plt.title(dataset_name+'数据集在'+model_name+'模型上的ACC', fontproperties=font)
plt.plot(train_accs, label="training acc")
plt.plot(val_accs, label="validating acc")
plt.plot(test_accs, label="testing acc")
plt.legend()
plt.savefig(pics_root+'/acc_'+pics_name)
plt.close()
5.3 绘制预测结果
def plot_pred(dv_set, model, device, lim=35., preds=None, targets=None):
    ''' Scatter-plot model predictions against ground-truth values.

    If preds/targets are not supplied, they are computed by running `model`
    over `dv_set` in eval mode with gradients disabled.
    '''
    if preds is None or targets is None:
        model.eval()
        preds, targets = [], []
        for x, y in dv_set:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                pred = model(x)
                preds.append(pred.detach().cpu())
                targets.append(y.detach().cpu())
        preds = torch.cat(preds, dim=0).numpy()
        targets = torch.cat(targets, dim=0).numpy()

    # bug fix: was a bare `figure(...)`, undefined unless separately imported;
    # use plt.figure for consistency with the rest of the function.
    plt.figure(figsize=(5, 5))
    plt.scatter(targets, preds, c='r', alpha=0.5)
    plt.plot([-0.2, lim], [-0.2, lim], c='b')  # y = x reference line
    plt.xlim(-0.2, lim)
    plt.ylim(-0.2, lim)
    plt.xlabel('ground truth value')
    plt.ylabel('predicted value')
    plt.title('Ground Truth v.s. Prediction')
    plt.show()
示例图: