1.损失函数知识总结参考:
深度学习笔记总结_GoAI的博客-CSDN博客
PyTorch 笔记.常见的PyTorch损失函数 - 知乎
Pytorch神经网络实战学习笔记_10 神经网络模块中的损失函数_LiBiGor的博客-CSDN博客
2.自定义损失函数学习参考:
pytorch教程之nn.Module类详解——使用Module类来自定义模型
pytorch教程之nn.Module类详解——使用Module类来自定义网络层
pytorch教程之损失函数详解——多种定义损失函数的方法
Loss Function Library - Keras & PyTorch | Kaggle
Pytorch如何自定义损失函数(Loss Function)? - 知乎
pytorch系列12 --pytorch自定义损失函数custom loss function_墨流觞的博客-
自定义损失函数 - image processing
pytorch教程之损失函数详解——多种定义损失函数的方法
Pytorch自定义网络结构+读取自己数据+自定义Loss 全过程代码示例
3.定义原始模版:
使用torch.Tensor提供的接口实现:
继承nn.Module类
在__init__函数中定义所需要的超参数,在forward函数中定义loss的计算方法。
所有的数学操作使用tensor提供的math operation
返回的tensor是0-dim的scalar
有可能会用到nn.functional中的一些操作
Pytorch如何自定义损失函数(Loss Function)? - 知乎
# Example: minimal template for a custom loss implemented as an nn.Module subclass.
# Fixed from the original snippet: added the missing colons, called
# super().__init__() (required before registering anything on an nn.Module),
# and stored the *argument* `parameters` — the original `self.params =
# self.parameters` stored the bound nn.Module.parameters method instead.
class myLoss(nn.Module):
    def __init__(self, parameters):
        super(myLoss, self).__init__()
        # Hyper-parameters needed by the loss computation.
        self.params = parameters

    def forward(self):
        # cal_loss is a placeholder for the user's actual computation;
        # it must return a 0-dim scalar tensor.
        loss = cal_loss(self.params)
        return loss

# Usage:
#   criterion = myLoss(params)
#   loss = criterion(...)
4.自定义函数方法
方法一:新建一个类
方案1:只定义loss函数的前向计算公式
在pytorch中定义了前向计算的公式,在训练时它会自动帮你计算反向传播。
class My_loss(nn.Module):
    """Mean-squared-error loss: the mean of the element-wise squared difference.

    Returns a 0-dim scalar tensor; autograd derives the backward pass
    automatically from the forward computation.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # Equivalent to torch.mean(torch.pow(x - y, 2)).
        diff = x - y
        return (diff * diff).mean()

# Usage:
#   criterion = My_loss()
#   loss = criterion(outputs, targets)
方案2:自定义loss函数的forward和backward
from numpy.fft import rfft2, irfft2
from torch.autograd import Function


class BadFFTFunction(Function):
    """Custom autograd Function wrapping NumPy's 2-D real FFT.

    Ported to the modern static-method Function API: the legacy
    instance-method style (``def forward(self, input)``) was deprecated in
    PyTorch 0.4 and raises an error in current releases. Callers invoke it
    as ``BadFFTFunction.apply(x)`` instead of instantiating the class.

    NOTE(review): as in the original tutorial snippet, the backward pass
    (irfft2 of the incoming gradient) is NOT the true gradient of
    ``abs(rfft2(x))`` — hence "Bad" in the name; kept as-is.
    """

    @staticmethod
    def forward(ctx, input):
        # .detach() is required before .numpy() when input tracks gradients.
        numpy_input = input.detach().numpy()
        result = abs(rfft2(numpy_input))
        # input.new(...) builds the output tensor with input's type/device.
        return input.new(result)

    @staticmethod
    def backward(ctx, grad_output):
        numpy_go = grad_output.numpy()
        result = irfft2(numpy_go)
        return grad_output.new(result)
Pytorch完整训练流程
1.限定使用GPU的序号
# Pin this process to physical GPU #3; CUDA will renumber it as device 0.
# Must run before torch initializes the CUDA context to take effect.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
# Echo the variable through the shell to confirm it is visible to children.
os.system('echo $CUDA_VISIBLE_DEVICES')
2、导入相关库
import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.data as data import torch.optim as optim from torch.autograd import Variable import numpy as np from Encoding import load_feature
3、自定义网络
class TransientModel(nn.Module):
    """Four stacked 1x1 convolutions with ReLU, squeezing channels 16->8->4->2->1.

    Spatial dimensions are unchanged (kernel_size=1); only the channel
    count is reduced. Attribute names conv1..conv4 are kept so state_dict
    keys match the original definition.
    """

    def __init__(self):
        super(TransientModel, self).__init__()
        self.conv1 = nn.Conv2d(16, 8, kernel_size=1)
        self.conv2 = nn.Conv2d(8, 4, kernel_size=1)
        self.conv3 = nn.Conv2d(4, 2, kernel_size=1)
        self.conv4 = nn.Conv2d(2, 1, kernel_size=1)

    def forward(self, x):
        # Apply each 1x1 conv followed by ReLU, in order.
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(layer(x))
        return x
4、自定义损失函数Loss
class MyLoss(nn.Module):
    """MSE between pred and the time-averaged truth.

    truth is averaged over dim 1, then both tensors are flattened to
    (-1, 2048) rows; the result is the mean squared difference — a 0-dim
    scalar. Assumes the trailing feature size is 2048 — TODO confirm
    against the feature extractor.
    """

    def __init__(self):
        super(MyLoss, self).__init__()
        # NOTE: the original contained `print '1'`, Python-2-only syntax that
        # is a SyntaxError under Python 3; the debug leftover was removed.

    def forward(self, pred, truth):
        # Average truth over dim 1 (presumably the time axis — verify).
        truth = torch.mean(truth, 1)
        truth = truth.view(-1, 2048)
        pred = pred.view(-1, 2048)
        # Per-row MSE over the 2048 features, then mean over rows.
        return torch.mean(torch.mean((pred - truth) ** 2, 1), 0)
5、自定义数据读取
class MyTrainData(data.Dataset):
    """Dataset of per-video frame features listed in Penn_train.txt.

    Each line of the list file is '<video_name> ...'; only the first
    space-separated field is used. __getitem__ loads that video's feature
    array and appends a trailing singleton axis.
    """

    def __init__(self):
        self.video_path = '/data/FrameFeature/Penn/'
        self.video_file = '/data/FrameFeature/Penn_train.txt'
        # `with` guarantees the list file is closed even if reading raises
        # (the original used bare open()/close() with no try/finally).
        with open(self.video_file, 'r') as fp:
            self.video_name = [line.strip().split(' ')[0] for line in fp]

    def __len__(self):
        return len(self.video_name)

    def __getitem__(self, index):
        # load_feature comes from the project-local Encoding module.
        # Renamed the local from `data` to avoid shadowing the
        # `torch.utils.data as data` module alias.
        feature = load_feature(os.path.join(self.video_path, self.video_name[index]))
        # Add a trailing singleton axis — presumably (T, C) -> (T, C, 1);
        # TODO confirm the shape load_feature returns.
        return np.expand_dims(feature, 2)
6、定义Train函数
def train(model, train_loader, myloss, optimizer, epoch):
    """Run one training epoch.

    The objective reconstructs the input: the model output is compared
    against the batch itself via `myloss`. Logs every 100 batches.

    Args:
        model: network already placed on the GPU.
        train_loader: DataLoader yielding input batches.
        myloss: loss module taking (output, target).
        optimizer: optimizer over model.parameters().
        epoch: current epoch index, for logging only.
    """
    model.train()
    for batch_idx, train_data in enumerate(train_loader):
        # Move the batch to GPU. Variable() is deprecated since PyTorch 0.4;
        # tensors are autograd-aware directly.
        train_data = train_data.cuda()
        optimizer.zero_grad()
        output = model(train_data)
        # Self-supervised: target is the input batch itself.
        loss = myloss(output, train_data)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            # loss.item() replaces loss.data.cpu().numpy()[0], which fails on
            # 0-dim tensors in modern PyTorch/NumPy.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tloss: {:.6f}'.format(
                epoch, batch_idx * len(train_data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
7.训练
if __name__ == '__main__':
    # Build network, loss, and data pipeline, then train for 10 epochs.
    net = TransientModel().cuda()
    criterion = MyLoss()
    dataset = MyTrainData()
    loader = data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)
    sgd = optim.SGD(net.parameters(), lr=0.001)
    for epoch in range(10):
        train(net, loader, criterion, sgd, epoch)
8、结果展示
完整代码
"""Complete training script: a 1x1-conv reconstruction model on Penn frame features.

Fixes applied to the original listing: Python-2 `print '1'` removed
(SyntaxError under Python 3), list file opened with `with`, deprecated
Variable() wrapper dropped, and `loss.data.cpu().numpy()[0]` replaced by
`loss.item()` (the old form fails on 0-dim tensors).
"""
# GPU selection must precede CUDA initialization.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
os.system('echo $CUDA_VISIBLE_DEVICES')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
from torch.autograd import Variable  # deprecated; retained for compatibility
import numpy as np
from Encoding import load_feature


class TransientModel(nn.Module):
    """Four stacked 1x1 convolutions with ReLU, squeezing channels 16->8->4->2->1."""

    def __init__(self):
        super(TransientModel, self).__init__()
        self.conv1 = nn.Conv2d(16, 8, kernel_size=1)
        self.conv2 = nn.Conv2d(8, 4, kernel_size=1)
        self.conv3 = nn.Conv2d(4, 2, kernel_size=1)
        self.conv4 = nn.Conv2d(2, 1, kernel_size=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        return x


class MyLoss(nn.Module):
    """MSE between pred and the time-averaged truth, flattened to (-1, 2048)."""

    def __init__(self):
        super(MyLoss, self).__init__()
        # Original `print '1'` debug line removed (Python-2-only syntax).

    def forward(self, pred, truth):
        truth = torch.mean(truth, 1)
        truth = truth.view(-1, 2048)
        pred = pred.view(-1, 2048)
        return torch.mean(torch.mean((pred - truth) ** 2, 1), 0)


class MyTrainData(data.Dataset):
    """Dataset of per-video frame features listed in Penn_train.txt."""

    def __init__(self):
        self.video_path = '/data/FrameFeature/Penn/'
        self.video_file = '/data/FrameFeature/Penn_train.txt'
        # `with` closes the list file even if reading raises.
        with open(self.video_file, 'r') as fp:
            self.video_name = [line.strip().split(' ')[0] for line in fp]

    def __len__(self):
        return len(self.video_name)

    def __getitem__(self, index):
        feature = load_feature(os.path.join(self.video_path, self.video_name[index]))
        # Trailing singleton axis — presumably (T, C) -> (T, C, 1); confirm.
        return np.expand_dims(feature, 2)


def train(model, train_loader, myloss, optimizer, epoch):
    """Run one epoch of self-supervised reconstruction training."""
    model.train()
    for batch_idx, train_data in enumerate(train_loader):
        train_data = train_data.cuda()
        optimizer.zero_grad()
        output = model(train_data)
        loss = myloss(output, train_data)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            # loss.item(): safe scalar extraction for 0-dim tensors.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tloss: {:.6f}'.format(
                epoch, batch_idx * len(train_data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def main():
    model = TransientModel().cuda()
    myloss = MyLoss()
    train_data = MyTrainData()
    train_loader = data.DataLoader(train_data, batch_size=1, shuffle=True, num_workers=1)
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    for epoch in range(10):
        train(model, train_loader, myloss, optimizer, epoch)


if __name__ == '__main__':
    main()