老程序员分享:Pytorch入门之Siamese网络

简介: 老程序员分享:Pytorch入门之Siamese网络

首次体验Pytorch,本文参考于:github and PyTorch 中文网人脸相似度对比


本文主要熟悉Pytorch大致流程,修改了读取数据部分。没有采用原作者的ImageFolder方法: ImageFolder(root, transform=None, target_transform=None, loader=default_loader)。而是采用了一种更自由的方法,利用了Dataset 和 DataLoader 自由实现,更加适合于不同数据的预处理导入工作。


Siamese网络不用多说,就是两个共享参数的CNN。每次的输入是一对图像+1个label,共3个值。注意label=0或1(又称正负样本),表示输入的两张图片match(匹配、同一个人)或no-match(不匹配、非同一人)。 下图是Siamese基本结构,图是其他论文随便找的,输入看做两张图片就好。只不过下图是两个光普段而已。


1. 数据处理


数据采用的是AT&T人脸数据。共40个人,每个人有10张脸。数据下载:AT&T


首先解压后发现文件夹下共40个文件夹,每个文件夹里有10张pgm图片。这里生成一个包含图片路径的train.txt文件共后续调用:


def convert(train=True):


if(train):


f=open(Config.txt_root, 'w')


data_path=root+'/train/'


if(not os.path.exists(data_path)):


os.makedirs(data_path)


for i in range(40):


for j in range(10):


img_path = data_path+'s'+str(i+1)+'/'+str(j+1)+'.pgm'


f.write(img_path+' '+str(i)+'\n')


f.close()


生成结果:每行前面为每张图片的完整路径, 后面数字为类别标签0~39。train文件夹下为s1~s40共40个子文件夹。


2. 定制个性化数据集


这一步骤主要继承了类Dataset,然后重写getitem和len方法即可:


class MyDataset(Dataset): # 集成Dataset类以定制


def init(self, txt, transform=None, target_transform=None, should_invert=False):


self.transform = transform


self.target_transform = target_transform


self.should_invert = should_invert


self.txt = txt # 之前生成的train.txt


def getitem(self, index):


line = linecache.getline(self.txt, random.randint(1, self.len())) # 随机选择一个人脸


line.strip('\n')


img0_list= line.split()


should_get_same_class = random.randint(0,1) # 随机数0或1,是否选择同一个人的脸,这里为了保证尽量使匹配和非匹配数据大致平衡(正负类样本相当)


if should_get_same_class: # 执行的话就挑一张同一个人的脸作为匹配样本对


while True:


img1_list = linecache.getline(self.txt, random.randint(1, self.len())).strip('\n').split()


if img0_list【1】==img1_list【1】:


break


else: # else就是随意挑一个人的脸作为非匹配样本对,当然也可能抽到同一个人的脸,概率较小而已


img1_list = linecache.getline(self.txt, random.randint(1,self.len())).strip('\n').split()


img0 = Image.open(img0_list【0】) # img_list都是大小为2的列表,list【0】为图像, list【1】为label


img1 = Image.open(img1_list【0】)


img0 = img0.convert("L") # 转为灰度


img1 = img1.convert("L")


if self.should_invert: # 是否进行像素反转操作,即0变1,1变0


img0 = PIL.ImageOps.invert(img0)


img1 = PIL.ImageOps.invert(img1)


if self.transform is not None: # 非常方便的transform操作,在实例化时可以进行任意定制


img0 = self.transform(img0)


img1 = self.transform(img1)


return img0, img1 , torch.from_numpy(np.array(【int(img1_list【1】!=img0_list【1】)】,dtype=np.float32)) # 注意一定要返回数据+标签, 这里返回一对图像+label(应由numpy转为tensor)


def len(self): # 数据总长


fh = open(self.txt, 'r')


num = len(fh.readlines())


fh.close()


return num


3. 制作双塔CNN


class SiameseNetwork(nn.Module):


def init(self):


super(SiameseNetwork, self).init()


self.cnn1 = nn.Sequential(


nn.ReflectionPad2d(1),


nn.Conv2d(1, 4, kernel_size=3),


nn.ReLU(inplace=True),


nn.BatchNorm2d(4),


nn.Dropout2d(p=.2),


nn.ReflectionPad2d(1),


nn.Conv2d(4, 8, kernel_size=3),


nn.ReLU(inplace=True),


nn.BatchNorm2d(8),


nn.Dropout2d(p=.2),


nn.ReflectionPad2d(1),


nn.Conv2d(8, 8, kernel_size=3),


nn.ReLU(inplace=True),


nn.BatchNorm2d(8),


nn.Dropout2d(p=.2),


)


self.fc1 = nn.Sequential(


nn.Linear(8100100, 500),


nn.ReLU(inplace=True),


nn.Linear(500, 500),


nn.ReLU(inplace=True),


nn.Linear(500, 5)


)


def forward_once(self, x):


output = self.cnn1(x)


output = output.view(output.size()【0】, -1)


output = self.fc1(output)


return output


def forward(self, input1, input2):


output1 = self.forward_once(input1)


output2 = self.forward_once(input2)


return output1, output2


很简单,没说的,注意前向传播是两张图同时输入进行。


4. 定制对比损失函数


# Custom Contrastive Loss


class ContrastiveLoss(torch.nn.Module):


"""


Contrastive loss function.


Based on:


"""


def init(self, margin=2.0):


super(ContrastiveLoss, self).init()


self.margin = margin


def forward(self, output1, output2, label):


euclidean_distance = F.pairwise_distance(output1, output2)


loss_contrastive = torch.mean((1-label) torch.pow(euclidean_distance, 2) + # calmp夹断用法


(label) torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))


return loss_contrastive


上面的损失函数为自己制作的,公式源于lecun文章:


Loss =


DW=


m为容忍度, Dw为两张图片的欧氏距离。


5. 训练一波


train_data = MyDataset(txt = Config.txt_root,transform=transforms.Compose(


【transforms.Resize((100,100)),transforms.ToTensor()】), should_invert=False) #Resize到100,100


train_dataloader = DataLoader(dataset=train_data, shuffle=True, num_workers=2, batch_size = Config.train_batch_size)


net = SiameseNetwork().cuda() # GPU加速


criterion = ContrastiveLoss()


optimizer = optim.Adam(net.parameters(), lr=0.0005)


counter = 【】


loss_history =【】


iteration_number =0


for epoch in range(0, Config.train_number_epochs):


for i, data in enumerate(train_dataloader, 0):


img0, img1, label = data


img0, img1, label = Variable(img0).cuda(), Variable(img1).cuda(), Variable(label).cuda()


output1, output2 = net(img0, img1)


optimizer.zero_grad()


loss_contrastive = criterion(output1, output2, label)


loss_contrastive.backward()


optimizer.step()


if i%10 == 0:


print("Epoch:{}, Current loss {}\n".format(epoch,loss_contrastive.data【0】))


iteration_number += 10


//代码效果参考:http://www.jhylw.com.cn/583124999.html

counter.append(iteration_number)

loss_history.append(loss_contrastive.data【0】)


show_plot(counter, loss_history) # plot 损失函数变化曲线


损失函数结果图:


batch_size=32, epoches=20, lr=0.001 batch_size=32, epoches=30, lr=0.0005


全部代码:


#!/usr/bin/env python3


# -- coding: utf-8 --


"""


Created on Wed Jan 24 10:00:24 2018


Paper: Siamese Neural Networks for One-shot Image Recognition


links:


"""


import torch


from torch.autograd import Variable


import os


import random


import linecache


import numpy as np


import //代码效果参考:http://www.jhylw.com.cn/055937059.html

torchvision

from torch.utils.data import Dataset, DataLoader


from torchvision import transforms


from PIL import Image


import PIL.ImageOps


import matplotlib.pyplot as plt


class Config():


root = '/home/lps/Spyder/data_faces/'


txt_root = '/home/lps/Spyder/data_faces/train.txt'


train_batch_size = 32


train_number_epochs = 30


# Helper functions


def imshow(img,text=None,should_save=False):


npimg = img.numpy()


plt.axis("off")


if text:


plt.text(75, 8, text, style='italic',fontweight='bold',


bbox={'facecolor':'white', 'alpha':0.8, 'pad':10})


plt.imshow(np.transpose(npimg, (1, 2, 0)))


plt.show()


def show_plot(iteration,loss):


plt.plot(iteration,loss)


plt.show()


def convert(train=True):


if(train):


f=open(Config.txt_root, 'w')


data_path=root+'/train/'


if(not os.path.exists(data_path)):


os.makedirs(data_path)


for i in range(40):


for j in range(10):


img_path = data_path+'s'+str(i+1)+'/'+str(j+1)+'.pgm'


f.write(img_path+' '+str(i)+'\n')


f.close()


#convert(True)


# ready the dataset, Not use ImageFolder as the author did


class MyDataset(Dataset):


def init(self, txt, transform=None, target_transform=None, should_invert=False):


self.transform = transform


self.target_transform = target_transform


self.should_invert = should_invert


self.txt = txt


def getitem(self, index):


line = linecache.getline(self.txt, random.randint(1, self.len()))


line.strip('\n')


img0_list= line.split()


should_get_same_class = random.randint(0,1)


if should_get_same_class:


while True:


img1_list = linecache.getline(self.txt, random.randint(1, self.len())).strip('\n').split()


if img0_list【1】==img1_list【1】:


break


else:


img1_list = linecache.getline(self.txt, random.randint(1,self.len())).strip('\n').split()


img0 = Image.open(img0_list【0】)


img1 = Image.open(img1_list【0】)


img0 = img0.convert("L")


img1 = img1.convert("L")


if self.should_invert:


img0 = PIL.ImageOps.invert(img0)


img1 = PIL.ImageOps.invert(img1)


if self.transform is not None:


img0 = self.transform(img0)


img1 = self.transform(img1)


目录
打赏
0
0
0
0
33
分享
相关文章
基于Pytorch 在昇腾上实现GCN图神经网络
本文详细讲解了如何在昇腾平台上使用PyTorch实现图神经网络(GCN)对Cora数据集进行分类训练。内容涵盖GCN背景、模型特点、网络架构剖析及实战分析。GCN通过聚合邻居节点信息实现“卷积”操作,适用于非欧氏结构数据。文章以两层GCN模型为例,结合Cora数据集(2708篇科学出版物,1433个特征,7种类别),展示了从数据加载到模型训练的完整流程。实验在NPU上运行,设置200个epoch,最终测试准确率达0.8040,内存占用约167M。
基于Pytorch 在昇腾上实现GCN图神经网络
Perforated Backpropagation:神经网络优化的创新技术及PyTorch使用指南
深度学习近年来在多个领域取得了显著进展,但其核心组件——人工神经元和反向传播算法自提出以来鲜有根本性突破。穿孔反向传播(Perforated Backpropagation)技术通过引入“树突”机制,模仿生物神经元的计算能力,实现了对传统神经元的增强。该技术利用基于协方差的损失函数训练树突节点,使其能够识别神经元分类中的异常模式,从而提升整体网络性能。实验表明,该方法不仅可提高模型精度(如BERT模型准确率提升3%-17%),还能实现高效模型压缩(参数减少44%而无性能损失)。这一革新为深度学习的基础构建模块带来了新的可能性,尤其适用于边缘设备和大规模模型优化场景。
85 16
Perforated Backpropagation:神经网络优化的创新技术及PyTorch使用指南
基于昇腾用PyTorch实现CTR模型DIN(Deep interest Netwok)网络
本文详细讲解了如何在昇腾平台上使用PyTorch训练推荐系统中的经典模型DIN(Deep Interest Network)。主要内容包括:DIN网络的创新点与架构剖析、Activation Unit和Attention模块的实现、Amazon-book数据集的介绍与预处理、模型训练过程定义及性能评估。通过实战演示,利用Amazon-book数据集训练DIN模型,最终评估其点击率预测性能。文中还提供了代码示例,帮助读者更好地理解每个步骤的实现细节。
基于Pytorch Gemotric在昇腾上实现GAT图神经网络
本实验基于昇腾平台,使用PyTorch实现图神经网络GAT(Graph Attention Networks)在Pubmed数据集上的分类任务。内容涵盖GAT网络的创新点分析、图注意力机制原理、多头注意力机制详解以及模型代码实战。实验通过两层GAT网络对Pubmed数据集进行训练,验证模型性能,并展示NPU上的内存使用情况。最终,模型在测试集上达到约36.60%的准确率。
PyTorch 实现FCN网络用于图像语义分割
本文详细讲解了在昇腾平台上使用PyTorch实现FCN(Fully Convolutional Networks)网络在VOC2012数据集上的训练过程。内容涵盖FCN的创新点分析、网络架构解析、代码实现以及端到端训练流程。重点包括全卷积结构替换全连接层、多尺度特征融合、跳跃连接和反卷积操作等技术细节。通过定义VOCSegDataset类处理数据集,构建FCN8s模型并完成训练与测试。实验结果展示了模型在图像分割任务中的应用效果,同时提供了内存使用优化的参考。
基于Pytorch Gemotric在昇腾上实现GraphSage图神经网络
本实验基于PyTorch Geometric,在昇腾平台上实现GraphSAGE图神经网络,使用CiteSeer数据集进行分类训练。内容涵盖GraphSAGE的创新点、算法原理、网络架构及实战分析。GraphSAGE通过采样和聚合节点邻居特征,支持归纳式学习,适用于未见节点的表征生成。实验包括模型搭建、训练与验证,并在NPU上运行,最终测试准确率达0.665。
PyTorch生态系统中的连续深度学习:使用Torchdyn实现连续时间神经网络
神经常微分方程(Neural ODEs)是深度学习领域的创新模型,将神经网络的离散变换扩展为连续时间动力系统。本文基于Torchdyn库介绍Neural ODE的实现与训练方法,涵盖数据集构建、模型构建、基于PyTorch Lightning的训练及实验结果可视化等内容。Torchdyn支持多种数值求解算法和高级特性,适用于生成模型、时间序列分析等领域。
294 77
PyTorch生态系统中的连续深度学习:使用Torchdyn实现连续时间神经网络
基于昇腾用PyTorch实现传统CTR模型WideDeep网络
本文介绍了如何在昇腾平台上使用PyTorch实现经典的WideDeep网络模型,以处理推荐系统中的点击率(CTR)预测问题。
363 66
云栖大会 | Terraform从入门到实践:快速构建你的第一张业务网络
云栖大会 | Terraform从入门到实践:快速构建你的第一张业务网络
116 1
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
软演员-评论家算法(Soft Actor-Critic, SAC)是深度强化学习领域的重要进展,基于最大熵框架优化策略,在探索与利用之间实现动态平衡。SAC通过双Q网络设计和自适应温度参数,提升了训练稳定性和样本效率。本文详细解析了SAC的数学原理、网络架构及PyTorch实现,涵盖演员网络的动作采样与对数概率计算、评论家网络的Q值估计及其损失函数,并介绍了完整的SAC智能体实现流程。SAC在连续动作空间中表现出色,具有高样本效率和稳定的训练过程,适合实际应用场景。
1475 7
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现

热门文章

最新文章

推荐镜像

更多
AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等

登录插画

登录以查看您的控制台资源

管理云资源
状态一览
快捷访问