6. Technical Outlook for ReID
First, ReID data is hard to collect, so applying unsupervised learning to improve ReID can reduce the dependence on labelled data; this is one active research direction. As the figure on the right shows, using GANs to generate data for ReID data augmentation has become a large sub-field in its own right, although it is only one way of applying unsupervised learning.
Second, video-based ReID. The datasets discussed above consist of single frames cropped from video, but real deployments also provide continuous frames, which carry more information and are closer to practical conditions, so many researchers are working on video-based ReID.
Third, cross-modality ReID. Recall the day/night problem: at night, images captured by infrared cameras can be matched against images from ordinary visible-light cameras.
Fourth, cross-scene transfer learning: how a ReID model learned in one scene, for example on Market1501, can be transferred to improve performance on the Duke dataset.
Fifth, application system design: building a complete system that puts ReID to practical use, for example in pedestrian retrieval.
7. Project Practice with the MGN ReID Method
This project puts the paper described above into practice. The dataset is Market1501, and the baseline network is a ResNet-50 model.
7.1 Dataset Processing and DataLoader Construction
1. Processing the Market1501 dataset:
from data.common import list_pictures

from torch.utils.data import dataset
from torchvision.datasets.folder import default_loader


class Market1501(dataset.Dataset):
    def __init__(self, args, transform, dtype):
        self.transform = transform
        self.loader = default_loader

        data_path = args.datadir
        if dtype == 'train':
            data_path += '/bounding_box_train'
        elif dtype == 'test':
            data_path += '/bounding_box_test'
        else:
            data_path += '/query'

        self.imgs = [path for path in list_pictures(data_path) if self.id(path) != -1]
        self._id2label = {_id: idx for idx, _id in enumerate(self.unique_ids)}

    def __getitem__(self, index):
        path = self.imgs[index]
        target = self._id2label[self.id(path)]

        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        return len(self.imgs)

    @staticmethod
    def id(file_path):
        """
        :param file_path: unix style file path
        :return: person id
        """
        return int(file_path.split('/')[-1].split('_')[0])

    @staticmethod
    def camera(file_path):
        """
        :param file_path: unix style file path
        :return: camera id
        """
        return int(file_path.split('/')[-1].split('_')[1][1])

    @property
    def ids(self):
        """
        :return: person id list corresponding to dataset image paths
        """
        return [self.id(path) for path in self.imgs]

    @property
    def unique_ids(self):
        """
        :return: unique person ids in ascending order
        """
        return sorted(set(self.ids))

    @property
    def cameras(self):
        """
        :return: camera id list corresponding to dataset image paths
        """
        return [self.camera(path) for path in self.imgs]
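The two static helpers above rely on the Market1501 naming convention: a file such as `0002_c1s1_000451_03.jpg` encodes person id 2 and camera 1, and distractor images are labelled `-1`, which is why the constructor filters on `self.id(path) != -1`. A standalone sketch of that parsing (the paths below are made up):

```python
# Standalone sketch of the Market1501 filename parsing used above.
# "0002_c1s1_000451_03.jpg": first field is the person id, "c1" the camera.
def person_id(file_path):
    return int(file_path.split('/')[-1].split('_')[0])

def camera_id(file_path):
    return int(file_path.split('/')[-1].split('_')[1][1])

path = '/data/market1501/bounding_box_train/0002_c1s1_000451_03.jpg'
print(person_id(path))  # -> 2
print(camera_id(path))  # -> 1

# Distractor images carry id -1 and are dropped by the dataset class:
print(person_id('/data/market1501/bounding_box_test/-1_c3s1_000051_02.jpg'))  # -> -1
```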
2. Building the DataLoader:
from importlib import import_module

from torchvision import transforms
from utils.random_erasing import RandomErasing
from data.sampler import RandomSampler
from torch.utils.data import dataloader


class Data:
    def __init__(self, args):
        train_list = [
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]
        if args.random_erasing:
            train_list.append(RandomErasing(probability=args.probability, mean=[0.0, 0.0, 0.0]))

        train_transform = transforms.Compose(train_list)

        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        if not args.test_only:
            module_train = import_module('data.' + args.data_train.lower())
            self.trainset = getattr(module_train, args.data_train)(args, train_transform, 'train')
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                sampler=RandomSampler(self.trainset, args.batchid, batch_image=args.batchimage),
                # shuffle=True,
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread)
        else:
            self.train_loader = None

        if args.data_test in ['Market1501']:
            module = import_module('data.' + args.data_train.lower())
            self.testset = getattr(module, args.data_test)(args, test_transform, 'test')
            self.queryset = getattr(module, args.data_test)(args, test_transform, 'query')
        else:
            raise Exception()

        self.test_loader = dataloader.DataLoader(self.testset, batch_size=args.batchtest, num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(self.queryset, batch_size=args.batchtest, num_workers=args.nThread)
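Note that the train loader uses `RandomSampler(self.trainset, args.batchid, batch_image=args.batchimage)` rather than plain shuffling: the triplet losses introduced later need every batch to contain several images of several identities (P identities with K images each). A minimal torch-free sketch of that P x K idea, with made-up names rather than the project's exact sampler code:

```python
import random
from collections import defaultdict

# Minimal sketch of P x K batch sampling (names are made up for
# illustration): each batch holds `p` identities with `k` images each,
# so triplet mining can always find positives and negatives.
def pk_batch(labels, p, k):
    by_id = defaultdict(list)
    for idx, pid in enumerate(labels):
        by_id[pid].append(idx)
    pids = random.sample(list(by_id), p)          # pick p identities
    batch = []
    for pid in pids:
        batch.extend(random.choices(by_id[pid], k=k))  # k images per identity
    return batch

labels = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3]
batch = pk_batch(labels, p=2, k=4)
print(len(batch))  # -> 8 (2 identities x 4 images)
```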
7.2 Data Augmentation
1. Random Erasing
from __future__ import absolute_import

import math
import random


class RandomErasing(object):
    """Randomly erase a rectangular region of a (C, H, W) image tensor.

    probability: chance of applying the erasing at all
    sl / sh:     min / max ratio of erased area to image area
    r1:          min aspect ratio of the erased rectangle
    mean:        per-channel fill value for the erased region
    """
    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def __call__(self, img):
        if random.uniform(0, 1) > self.probability:
            return img

        for attempt in range(100):
            area = img.size()[1] * img.size()[2]

            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < img.size()[2] and h < img.size()[1]:
                x1 = random.randint(0, img.size()[1] - h)
                y1 = random.randint(0, img.size()[2] - w)
                if img.size()[0] == 3:
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                    img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
                    img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
                else:
                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
                return img

        return img
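The rectangle-sampling logic is easiest to see in isolation. Below is a torch-free sketch of the same formulas (the helper name `sample_rect` is made up for illustration; it only returns the region that `__call__` would erase):

```python
import math
import random

# Torch-free sketch of Random Erasing's rectangle sampling: draw an area
# ratio in [sl, sh] and an aspect ratio in [r1, 1/r1], derive (h, w), and
# retry until the rectangle fits inside the image.
def sample_rect(h_img, w_img, sl=0.02, sh=0.4, r1=0.3):
    for _ in range(100):
        target_area = random.uniform(sl, sh) * h_img * w_img
        aspect_ratio = random.uniform(r1, 1 / r1)
        h = int(round(math.sqrt(target_area * aspect_ratio)))
        w = int(round(math.sqrt(target_area / aspect_ratio)))
        if h < h_img and w < w_img:
            x1 = random.randint(0, h_img - h)
            y1 = random.randint(0, w_img - w)
            return x1, y1, h, w
    return None  # extremely unlikely after 100 attempts

rect = sample_rect(384, 128)
print(rect)  # a random (x1, y1, h, w) fully inside the 384x128 image
```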
2. Other built-in torchvision transforms:
def __init__(self, args):
    train_list = [
        transforms.Resize((args.height, args.width), interpolation=3),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
    if args.random_erasing:
        train_list.append(RandomErasing(probability=args.probability, mean=[0.0, 0.0, 0.0]))

    train_transform = transforms.Compose(train_list)

    test_transform = transforms.Compose([
        transforms.Resize((args.height, args.width), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
7.3 TripletSemihard Loss and Triplet Loss
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
from torch import nn
from torch.nn import functional as F


class TripletSemihardLoss(nn.Module):
    """
    Shape:
        - Input: :math:`(N, C)` where `C = number of channels`
        - Target: :math:`(N)`
        - Output: scalar.
    """

    def __init__(self, device, margin=0, size_average=True):
        super(TripletSemihardLoss, self).__init__()
        self.margin = margin
        self.size_average = size_average
        self.device = device

    def forward(self, input, target):
        y_true = target.int().unsqueeze(-1)
        same_id = torch.eq(y_true, y_true.t()).type_as(input)
        pos_mask = same_id
        neg_mask = 1 - same_id

        def _mask_max(input_tensor, mask, axis=None, keepdims=False):
            # Suppress masked-out entries before taking the max
            input_tensor = input_tensor - 1e6 * (1 - mask)
            _max, _idx = torch.max(input_tensor, dim=axis, keepdim=keepdims)
            return _max, _idx

        def _mask_min(input_tensor, mask, axis=None, keepdims=False):
            # Inflate masked-out entries before taking the min
            input_tensor = input_tensor + 1e6 * (1 - mask)
            _min, _idx = torch.min(input_tensor, dim=axis, keepdim=keepdims)
            return _min, _idx

        # output[i, j] = || feature[i, :] - feature[j, :] ||_2
        dist_squared = torch.sum(input ** 2, dim=1, keepdim=True) + \
                       torch.sum(input.t() ** 2, dim=0, keepdim=True) - \
                       2.0 * torch.matmul(input, input.t())
        dist = dist_squared.clamp(min=1e-16).sqrt()

        pos_max, pos_idx = _mask_max(dist, pos_mask, axis=-1)
        neg_min, neg_idx = _mask_min(dist, neg_mask, axis=-1)

        # loss(x, y) = max(0, -y * (x1 - x2) + margin)
        y = torch.ones(same_id.size()[0]).to(self.device)
        return F.margin_ranking_loss(neg_min.float(), pos_max.float(), y,
                                     self.margin, self.size_average)


class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.

    Args:
        margin (float): margin for triplet.
    """

    def __init__(self, margin=0.3, mutual_flag=False):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        self.mutual = mutual_flag

    def forward(self, inputs, targets):
        """
        Args:
            inputs: feature matrix with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (batch_size)
        """
        n = inputs.size(0)
        # inputs = 1. * inputs / (torch.norm(inputs, 2, dim=-1, keepdim=True).expand_as(inputs) + 1e-12)
        # Compute pairwise squared Euclidean distance
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)  # dist <- dist - 2 * inputs @ inputs.t()
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
            dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)
        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        if self.mutual:
            return loss, dist
        return loss
7.4 The MGN Network Model
import copy

import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50, Bottleneck


def make_model(args):
    return MGN(args)


class MGN(nn.Module):
    def __init__(self, args):
        super(MGN, self).__init__()
        num_classes = args.num_classes

        resnet = resnet50(pretrained=True)

        # Shared backbone: everything up to the first block of layer3
        self.backone = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3[0],
        )

        res_conv4 = nn.Sequential(*resnet.layer3[1:])

        res_g_conv5 = resnet.layer4

        # Stride-1 copy of layer4 so the part branches keep a larger feature map
        res_p_conv5 = nn.Sequential(
            Bottleneck(1024, 512, downsample=nn.Sequential(nn.Conv2d(1024, 2048, 1, bias=False),
                                                           nn.BatchNorm2d(2048))),
            Bottleneck(2048, 512),
            Bottleneck(2048, 512))
        res_p_conv5.load_state_dict(resnet.layer4.state_dict())

        # Three branches: one global (p1), two global+part (p2, p3)
        self.p1 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5))
        self.p2 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))
        self.p3 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))

        if args.pool == 'max':
            pool2d = nn.MaxPool2d
        elif args.pool == 'avg':
            pool2d = nn.AvgPool2d
        else:
            raise Exception()

        self.maxpool_zg_p1 = pool2d(kernel_size=(12, 4))
        self.maxpool_zg_p2 = pool2d(kernel_size=(24, 8))
        self.maxpool_zg_p3 = pool2d(kernel_size=(24, 8))
        self.maxpool_zp2 = pool2d(kernel_size=(12, 8))   # p2: two horizontal stripes
        self.maxpool_zp3 = pool2d(kernel_size=(8, 8))    # p3: three horizontal stripes

        # 1x1 conv + BN + ReLU reducing 2048-d pooled features to args.feats
        reduction = nn.Sequential(nn.Conv2d(2048, args.feats, 1, bias=False),
                                  nn.BatchNorm2d(args.feats), nn.ReLU())
        self._init_reduction(reduction)
        self.reduction_0 = copy.deepcopy(reduction)
        self.reduction_1 = copy.deepcopy(reduction)
        self.reduction_2 = copy.deepcopy(reduction)
        self.reduction_3 = copy.deepcopy(reduction)
        self.reduction_4 = copy.deepcopy(reduction)
        self.reduction_5 = copy.deepcopy(reduction)
        self.reduction_6 = copy.deepcopy(reduction)
        self.reduction_7 = copy.deepcopy(reduction)

        # self.fc_id_2048_0 = nn.Linear(2048, num_classes)
        self.fc_id_2048_0 = nn.Linear(args.feats, num_classes)
        self.fc_id_2048_1 = nn.Linear(args.feats, num_classes)
        self.fc_id_2048_2 = nn.Linear(args.feats, num_classes)

        self.fc_id_256_1_0 = nn.Linear(args.feats, num_classes)
        self.fc_id_256_1_1 = nn.Linear(args.feats, num_classes)
        self.fc_id_256_2_0 = nn.Linear(args.feats, num_classes)
        self.fc_id_256_2_1 = nn.Linear(args.feats, num_classes)
        self.fc_id_256_2_2 = nn.Linear(args.feats, num_classes)

        self._init_fc(self.fc_id_2048_0)
        self._init_fc(self.fc_id_2048_1)
        self._init_fc(self.fc_id_2048_2)
        self._init_fc(self.fc_id_256_1_0)
        self._init_fc(self.fc_id_256_1_1)
        self._init_fc(self.fc_id_256_2_0)
        self._init_fc(self.fc_id_256_2_1)
        self._init_fc(self.fc_id_256_2_2)

    @staticmethod
    def _init_reduction(reduction):
        # conv
        nn.init.kaiming_normal_(reduction[0].weight, mode='fan_in')
        # nn.init.constant_(reduction[0].bias, 0.)
        # bn
        nn.init.normal_(reduction[1].weight, mean=1., std=0.02)
        nn.init.constant_(reduction[1].bias, 0.)

    @staticmethod
    def _init_fc(fc):
        nn.init.kaiming_normal_(fc.weight, mode='fan_out')
        # nn.init.normal_(fc.weight, std=0.001)
        nn.init.constant_(fc.bias, 0.)

    def forward(self, x):
        x = self.backone(x)

        p1 = self.p1(x)
        p2 = self.p2(x)
        p3 = self.p3(x)

        # Global pooled features of each branch
        zg_p1 = self.maxpool_zg_p1(p1)
        zg_p2 = self.maxpool_zg_p2(p2)
        zg_p3 = self.maxpool_zg_p3(p3)

        # p2 split into two horizontal parts
        zp2 = self.maxpool_zp2(p2)
        z0_p2 = zp2[:, :, 0:1, :]
        z1_p2 = zp2[:, :, 1:2, :]

        # p3 split into three horizontal parts
        zp3 = self.maxpool_zp3(p3)
        z0_p3 = zp3[:, :, 0:1, :]
        z1_p3 = zp3[:, :, 1:2, :]
        z2_p3 = zp3[:, :, 2:3, :]

        fg_p1 = self.reduction_0(zg_p1).squeeze(dim=3).squeeze(dim=2)
        fg_p2 = self.reduction_1(zg_p2).squeeze(dim=3).squeeze(dim=2)
        fg_p3 = self.reduction_2(zg_p3).squeeze(dim=3).squeeze(dim=2)
        f0_p2 = self.reduction_3(z0_p2).squeeze(dim=3).squeeze(dim=2)
        f1_p2 = self.reduction_4(z1_p2).squeeze(dim=3).squeeze(dim=2)
        f0_p3 = self.reduction_5(z0_p3).squeeze(dim=3).squeeze(dim=2)
        f1_p3 = self.reduction_6(z1_p3).squeeze(dim=3).squeeze(dim=2)
        f2_p3 = self.reduction_7(z2_p3).squeeze(dim=3).squeeze(dim=2)

        '''
        l_p1 = self.fc_id_2048_0(zg_p1.squeeze(dim=3).squeeze(dim=2))
        l_p2 = self.fc_id_2048_1(zg_p2.squeeze(dim=3).squeeze(dim=2))
        l_p3 = self.fc_id_2048_2(zg_p3.squeeze(dim=3).squeeze(dim=2))
        '''
        l_p1 = self.fc_id_2048_0(fg_p1)
        l_p2 = self.fc_id_2048_1(fg_p2)
        l_p3 = self.fc_id_2048_2(fg_p3)

        l0_p2 = self.fc_id_256_1_0(f0_p2)
        l1_p2 = self.fc_id_256_1_1(f1_p2)
        l0_p3 = self.fc_id_256_2_0(f0_p3)
        l1_p3 = self.fc_id_256_2_1(f1_p3)
        l2_p3 = self.fc_id_256_2_2(f2_p3)

        predict = torch.cat([fg_p1, fg_p2, fg_p3, f0_p2, f1_p2, f0_p3, f1_p3, f2_p3], dim=1)

        return predict, fg_p1, fg_p2, fg_p3, l_p1, l_p2, l_p3, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3
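A quick sanity check on the pooling kernel sizes: `p1` keeps ResNet-50's strided `layer4` (overall stride 32), while `p2`/`p3` use the stride-1 copy (overall stride 16). Assuming the 384x128 input size MGN is usually trained with (not stated in this section), the global feature maps come out as 12x4 and 24x8, matching `maxpool_zg_p1` and `maxpool_zg_p2`/`p3`:

```python
# Hypothetical helper: check that MGN's pooling kernels match the
# feature-map sizes a 384x128 input would produce (assumed input size).
def feature_map_size(input_hw, total_stride):
    h, w = input_hw
    return h // total_stride, w // total_stride

print(feature_map_size((384, 128), 32))  # -> (12, 4): p1 branch, strided layer4
print(feature_map_size((384, 128), 16))  # -> (24, 8): p2/p3 branches, stride-1 layer4
```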
Main entry point:
import data
import loss
import torch
import model
from trainer import Trainer
from option import args
import utils.utility as utility

ckpt = utility.checkpoint(args)
loader = data.Data(args)
model = model.Model(args, ckpt)
loss = loss.Loss(args, ckpt) if not args.test_only else None
trainer = Trainer(args, model, loss, loader, ckpt)

n = 0
if __name__ == '__main__':
    while not trainer.terminate():
        n += 1
        trainer.train()
        if args.test_every != 0 and n % args.test_every == 0:
            trainer.test()
7.5 MGN-ReID Training Procedure and Test Results