由于这些数据集是通过爬虫直接从网上下载的,其中有很多错误的图片,需要把它们找出来并删除。
1. 使用 RetinaFace 对给定的图片做人脸检测,对于检测不到 landmark / bounding box 的图片逐个进行分析。
2. 将图像缩小后再送进网络进行检测,发现很多原先检测不到的人脸都能检测出来了;实际上 RetinaFace 对尺寸较小的图片检测效果会好很多。
https://download.csdn.net/download/m0_51004308/19513502?spm=1001.2014.3001.5501 这是基于 Python 版本的完整数据集清洗代码(包括 RetinaFace 模型文件;运行操作也非常简单,只需把你的数据集路径填入其中即可。代码有详细注释,有需要可以了解。)
import numpy as np
import torch
import cv2
from retinaface.detector import RetinafaceDetector, RetinafaceDetector_dnn
from align_faces import align_process
class retinaface:
    """Face detector / cropper built on the PyTorch ``RetinafaceDetector``.

    Args:
        device: device string forwarded to ``RetinafaceDetector``
            (``'cuda'`` by default).
        align: when True, each face crop is produced by ``align_process`` at
            224x224 instead of being a plain bounding-box slice of the image.
    """

    def __init__(self, device='cuda', align=False):
        self.retinaface = RetinafaceDetector(device=device)
        self.align = align

    @staticmethod
    def _crop(srcimg, box):
        """Return ``srcimg[y1:y2, x1:x2]`` with ``box`` clamped to the image.

        Detected boxes can extend past the image border. A negative start
        index would make NumPy count from the end of the axis and return an
        empty or wrong crop, so each coordinate is clamped to
        ``[0, width]`` / ``[0, height]`` before slicing.
        """
        height, width = srcimg.shape[:2]
        x1, y1, x2, y2 = box
        x1, x2 = min(max(x1, 0), width), min(max(x2, 0), width)
        y1, y2 = min(max(y1, 0), height), min(max(y2, 0), height)
        return srcimg[y1:y2, x1:x2]

    def _parse_detections(self, srcimg):
        """Run the detector and return a list of ``(box, landmark, face_roi)``.

        ``box`` is ``[x1, y1, x2, y2]`` as Python ints, truncated from the
        detector output and NOT clamped.

        ``landmark`` is a ``(5, 2)`` float array whose rows are used as (x, y)
        points. Each landmark row is reshaped to ``(2, 5)`` and transposed, so
        the row is assumed to hold five x values followed by five y values.
        TODO: confirm against the detector.

        ``face_roi`` is the aligned 224x224 crop when ``self.align`` is set,
        otherwise the clamped bounding-box slice of ``srcimg``.
        """
        bounding_boxes, landmarks = self.retinaface.detect_faces(srcimg)
        faces = []
        # zip() also copes with an empty list as well as an (N, >=5) array.
        for det, lm in zip(bounding_boxes, landmarks):
            # det[4] is the score (see original commented-out line); unused.
            box = det[:4].astype(np.int32).tolist()
            landmark = lm.reshape((2, 5)).T
            if self.align:
                face_roi = align_process(srcimg, det[:4], landmark, (224, 224))
            else:
                face_roi = self._crop(srcimg, box)
            faces.append((box, landmark, face_roi))
        return faces

    def detect(self, srcimg):
        """Detect faces and return ``(drawimg, face_rois)``.

        ``drawimg`` is a copy of ``srcimg``. Each box is drawn on it with color
        (0, 0, 255) and each of the 5 landmarks as a filled (0, 255, 0) dot.
        ``face_rois`` holds one crop per detected face (see
        ``_parse_detections``). ``srcimg`` itself is not drawn on.
        """
        faces = self._parse_detections(srcimg)
        drawimg, face_rois = srcimg.copy(), []
        for (x1, y1, x2, y2), landmark, face_roi in faces:
            # Plain Python ints: some OpenCV builds reject NumPy scalars as points.
            cv2.rectangle(drawimg, (x1, y1), (x2, y2), (0, 0, 255), thickness=2)
            for x, y in landmark.astype(np.int32).tolist():
                cv2.circle(drawimg, (x, y), 2, (0, 255, 0), thickness=-1)
            face_rois.append(face_roi)
        return drawimg, face_rois

    def get_face(self, srcimg):
        """Detect faces and return ``(boxs, face_rois)`` without drawing.

        Each entry of ``boxs`` is a 14-tuple of ints:
        ``(x1, y1, x2, y2, lx0, ly0, ..., lx4, ly4)``. The box is unclamped
        detector output, followed by the 5 landmark points.
        ``face_rois`` holds the matching crops.
        """
        boxs, face_rois = [], []
        for box, landmark, face_roi in self._parse_detections(srcimg):
            boxs.append(tuple(box + landmark.astype(np.int32).ravel().tolist()))
            face_rois.append(face_roi)
        return boxs, face_rois
class retinaface_dnn:
    """Face detector / cropper built on ``RetinafaceDetector_dnn``.

    According to this file's author, the DNN backend resizes its input to a
    fixed size before inference, while the PyTorch backend does not. Its
    outputs are therefore not expected to match the PyTorch wrapper exactly.

    Args:
        align: when True, each face crop is produced by ``align_process`` at
            224x224 instead of being a plain bounding-box slice of the image.
    """

    def __init__(self, align=False):
        self.net = RetinafaceDetector_dnn()
        self.align = align

    @staticmethod
    def _crop(srcimg, box):
        """Return ``srcimg[y1:y2, x1:x2]`` with ``box`` clamped to the image.

        Detected boxes can extend past the image border. A negative start
        index would make NumPy count from the end of the axis and return an
        empty or wrong crop, so each coordinate is clamped to
        ``[0, width]`` / ``[0, height]`` before slicing.
        """
        height, width = srcimg.shape[:2]
        x1, y1, x2, y2 = box
        x1, x2 = min(max(x1, 0), width), min(max(x2, 0), width)
        y1, y2 = min(max(y1, 0), height), min(max(y2, 0), height)
        return srcimg[y1:y2, x1:x2]

    def _parse_detections(self, srcimg):
        """Run the detector and return a list of ``(box, landmark, face_roi)``.

        ``box`` is ``[x1, y1, x2, y2]`` as Python ints, truncated from the
        detector output and NOT clamped.

        ``landmark`` is a ``(5, 2)`` float array whose rows are used as (x, y)
        points. Each landmark row is reshaped to ``(2, 5)`` and transposed, so
        the row is assumed to hold five x values followed by five y values.
        TODO: confirm against the detector.

        ``face_roi`` is the aligned 224x224 crop when ``self.align`` is set,
        otherwise the clamped bounding-box slice of ``srcimg``.
        """
        bounding_boxes, landmarks = self.net.detect_faces(srcimg)
        faces = []
        # zip() also copes with an empty list as well as an (N, >=5) array.
        for det, lm in zip(bounding_boxes, landmarks):
            # det[4] is the score (see original commented-out line); unused.
            box = det[:4].astype(np.int32).tolist()
            landmark = lm.reshape((2, 5)).T
            if self.align:
                face_roi = align_process(srcimg, det[:4], landmark, (224, 224))
            else:
                face_roi = self._crop(srcimg, box)
            faces.append((box, landmark, face_roi))
        return faces

    def detect(self, srcimg):
        """Detect faces and return ``(drawimg, face_rois)``.

        ``drawimg`` is a copy of ``srcimg``. Each box is drawn on it with color
        (0, 0, 255) and each of the 5 landmarks as a filled (0, 255, 0) dot.
        ``face_rois`` holds one crop per detected face (see
        ``_parse_detections``). ``srcimg`` itself is not drawn on.
        """
        faces = self._parse_detections(srcimg)
        drawimg, face_rois = srcimg.copy(), []
        for (x1, y1, x2, y2), landmark, face_roi in faces:
            # Plain Python ints: some OpenCV builds reject NumPy scalars as points.
            cv2.rectangle(drawimg, (x1, y1), (x2, y2), (0, 0, 255), thickness=2)
            for x, y in landmark.astype(np.int32).tolist():
                cv2.circle(drawimg, (x, y), 2, (0, 255, 0), thickness=-1)
            face_rois.append(face_roi)
        return drawimg, face_rois

    def get_face(self, srcimg):
        """Detect faces and return ``(boxs, face_rois)`` without drawing.

        Each entry of ``boxs`` is a 14-tuple of ints:
        ``(x1, y1, x2, y2, lx0, ly0, ..., lx4, ly4)``. The box is unclamped
        detector output, followed by the 5 landmark points.
        ``face_rois`` holds the matching crops.
        """
        boxs, face_rois = [], []
        for box, landmark, face_roi in self._parse_detections(srcimg):
            boxs.append(tuple(box + landmark.astype(np.int32).ravel().tolist()))
            face_rois.append(face_roi)
        return boxs, face_rois
if __name__ == "__main__":
    # Dataset-cleaning entry point: list every image in which no face is
    # detected (or which cannot be decoded at all) so it can be reviewed and
    # deleted.
    #
    # Usage: python this_script.py [IMGROOT] [ERROR_TXT]
    import os
    import sys

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # retinaface_detect = retinaface(device=device, align=True)
    retinaface_detect = retinaface_dnn(align=True)
    # Difference between the two backends: the PyTorch version feeds the image
    # to the network without resizing, while the DNN version first resizes it
    # to a fixed size. Because their inputs differ, their outputs are not
    # compared.

    # IMGROOT holds one sub-folder per person (folder name = person name), each
    # containing that person's portrait photos.
    imgroot = (sys.argv[1] if len(sys.argv) > 1
               else '/home/lqs/Documents/arcface-pytorch-master/data/CASIA-WebFace/')
    # imgroot = '/home/lqs/Documents/arcface-pytorch-master/data/902_Pic_Fea/'
    # The error list is written to and read back from this SAME path.
    error_path = sys.argv[2] if len(sys.argv) > 2 else 'error.txt'

    im = 0
    # Skip stray files at the top level; only person folders are scanned.
    dirlist = sorted(
        d for d in os.listdir(imgroot) if os.path.isdir(os.path.join(imgroot, d))
    )
    # `with` guarantees the list is flushed and closed before it is re-read.
    with open(error_path, 'w', encoding='utf-8') as f:
        for i, name in enumerate(dirlist):
            sys.stdout.write("\rRun person{0}-{1}, name:{2}\n".format(i, len(dirlist), name))
            sys.stdout.flush()
            imgdir = os.path.join(imgroot, name)
            for imgname in sorted(os.listdir(imgdir)):
                imapath = os.path.join(imgdir, imgname)
                srcimg = cv2.imread(imapath)
                # cv2.imread returns None for files it cannot decode (broken
                # downloads, non-images); count those as errors instead of
                # crashing inside the detector.
                if srcimg is None:
                    face_rois = []
                else:
                    # get_face gives the same crops as detect() without the
                    # extra image copy and drawing.
                    _, face_rois = retinaface_detect.get_face(srcimg)
                if not face_rois:
                    f.write(imapath + '\n')
                    im += 1
                    print('the error pic', imapath)
    print('the totle error pic {} 个'.format(im))
    with open(error_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()  # re-read to cross-check the count
    print('the len of txt is {}'.format(len(lines)))