### 数据集
该图像数据集共包含8000张图像，分为两个类别：安全帽与人。其中200多张图像作为验证集，其余作为训练集。
### 模型训练
准备好数据集后，直接运行下面的命令即可开始训练：
```bash
yolo train model=yolov8s.pt data=hat_dataset.yaml epochs=50 imgsz=640 batch=4
```
### 导出与测试
使用下面的命令将模型导出为ONNX格式，并对测试数据运行预测，检验模型的实际推理效果：
```bash
yolo export model=hat_best.pt format=onnx
yolo predict model=hat_best.pt source=./hats
```
### 部署推理
转换为ONNX格式文件后，基于OpenVINO的Python API部署推理，相关代码如下：
# OpenVINO inference script for the exported YOLOv8 hat detector.
#
# This fragment relies on names defined earlier in the full script (not shown):
#   ie            -- an OpenVINO Core instance (provides read_model/compile_model)
#   format_yolov8 -- prepares a frame for the network; presumably pads it to a
#                    square so the x/y scale factors below are equal (TODO confirm)
#   colors        -- sequence of BGR colours used to draw each class
#   class_list    -- class names indexed by class id
# plus the modules cv (OpenCV), np (NumPy) and time.

INPUT_SIZE = 640        # network input resolution passed to blobFromImage
SCORE_THRESHOLD = 0.25  # minimum best-class score for a candidate box
NMS_THRESHOLD = 0.45    # IoU threshold used by non-maximum suppression


def _decode_predictions(res, x_factor, y_factor, score_threshold=SCORE_THRESHOLD):
    """Convert the raw YOLOv8 output tensor into candidate detections.

    Args:
        res: raw network output of shape (1, 4 + num_classes, num_candidates);
            for each candidate the first four values are (cx, cy, w, h) in
            network-input pixels and the remaining values are per-class scores.
            (The 80-class COCO model gives 1x84x8400; this 2-class model is
            expected to give 1x6x8400 -- TODO confirm.)
        x_factor: horizontal scale from network-input pixels to image pixels.
        y_factor: vertical scale from network-input pixels to image pixels.
        score_threshold: candidates whose best class score is not strictly
            above this value are discarded.

    Returns:
        (boxes, confidences, class_ids), aligned by index: boxes are
        [left, top, width, height] integer lists in image pixels,
        confidences are floats, class_ids are ints.
    """
    rows = np.squeeze(res, 0).T  # one row per candidate
    scores = rows[:, 4:]
    best_ids = np.argmax(scores, axis=1)
    best_scores = scores[np.arange(rows.shape[0]), best_ids]
    keep = best_scores > score_threshold

    # Work in float64 so results match the original per-row Python-float math.
    cx, cy, w, h = rows[keep, :4].astype(np.float64).T
    # astype(int) truncates toward zero, like int() did in the per-row loop.
    left = ((cx - 0.5 * w) * x_factor).astype(int)
    top = ((cy - 0.5 * h) * y_factor).astype(int)
    width = (w * x_factor).astype(int)
    height = (h * y_factor).astype(int)

    boxes = np.stack([left, top, width, height], axis=1).tolist()
    confidences = best_scores[keep].astype(float).tolist()
    class_ids = best_ids[keep].tolist()
    return boxes, confidences, class_ids


# Load the exported ONNX model; OpenVINO reads ONNX directly (no IR conversion).
model = ie.read_model(model="hat_best.onnx")
compiled_model = ie.compile_model(model=model, device_name="CPU")
output_layer = compiled_model.output(0)

capture = cv.VideoCapture("D:/images/video/hat_test.mp4")
try:
    while True:
        ok, frame = capture.read()
        if not ok or frame is None:
            print("End of stream")
            break
        bgr = format_yolov8(frame)
        img_h, img_w = bgr.shape[:2]

        # perf_counter is a monotonic, high-resolution clock meant for intervals.
        start = time.perf_counter()
        image = cv.dnn.blobFromImage(
            bgr, 1 / 255.0, (INPUT_SIZE, INPUT_SIZE), swapRB=True, crop=False
        )
        res = compiled_model([image])[output_layer]

        boxes, confidences, class_ids = _decode_predictions(
            res, img_w / INPUT_SIZE, img_h / INPUT_SIZE
        )
        indexes = cv.dnn.NMSBoxes(boxes, confidences, SCORE_THRESHOLD, NMS_THRESHOLD)

        # Older OpenCV returns an Nx1 array, newer a flat array (or an empty
        # tuple when nothing survives); flatten so each index is a scalar.
        for index in np.asarray(indexes, dtype=int).flatten():
            left, top, width, height = boxes[index]
            class_id = class_ids[index]
            color = colors[class_id % len(colors)]
            cv.rectangle(frame, (left, top, width, height), color, 2)
            # Filled background strip above the box for the class label.
            cv.rectangle(frame, (left, top - 20), (left + width, top), color, -1)
            cv.putText(frame, class_list[class_id], (left, top - 10),
                       cv.FONT_HERSHEY_SIMPLEX, .5, (0, 0, 0))

        elapsed = time.perf_counter() - start
        fps = 1 / elapsed if elapsed > 0 else float("inf")
        fps_label = "FPS: %.2f" % fps
        cv.putText(frame, fps_label, (20, 45), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        cv.imshow("YOLOv8 hat Detection", frame)
        if cv.waitKey(1) == 27:  # Esc key stops playback
            break
finally:
    # Always free the video handle and windows, even if inference raises.
    capture.release()
    cv.destroyAllWindows()
认真学习 YOLOv8 点这里。