AI计算机视觉笔记二十六:YOLOV8自训练关键点检测

简介: 本文档详细记录了使用YOLOv8训练关键点检测模型的过程。首先通过清华源安装YOLOv8,并验证安装。接着通过示例权重文件与测试图片`bus.jpg`演示预测流程。为准备训练数据,文档介绍了如何使用`labelme`标注工具进行关键点标注,并提供了一个Python脚本`labelme2yolo.py`将标注结果从JSON格式转换为YOLO所需的TXT格式。随后,通过Jupyter Notebook可视化标注结果确保准确性。最后,文档展示了如何组织数据集目录结构,并提供了训练与测试代码示例,包括配置文件`smoke.yaml`及训练脚本`train.py`,帮助读者完成自定义模型的训练与评估。

记录学习YOLOV8过程,自训练关键点检测模型。

清华源:-i https://mirror.baidu.com/pypi/simple

1、yolov8安装

git clone https://github.com/ultralytics/ultralytics
​
cd ultralytics
​
pip install -e .

安装成功后,使用命令 yolo 简单看下版本

(yolov8) llh@anhao:/$ yolo version
​
8.0.206

2、简单测试

下载权重文件

GitHub - ultralytics/ultralytics: NEW - YOLOv8 🚀 in PyTorch > ONNX > OpenVINO > CoreML > TFLite
image.png
直接点击下载。

获取测试图片的文件在ultralytics\assets目录有,使用的是的里面的bus.jpg测试。

使用 yolo 命令进行测试

yolo detect predict model=./yolov8n.pt source=./bus.jpg

image.png
输出在runs/detect/predict/目录下。
image.png

3、安装labelme

pip install labelme

直接在终端运行labelme打开软件
image.png
先取消“保存图片数据”(减少标注文件大小);在文件下
image.png
打开文件目录

image.png
接下来标注

先标注检测框,检测框用Create Rectangle(Ctrl+N)

填写类别名称

填写group_id,用于匹配后续标注的关键点,以当前画面中出现的顺序标注即可。
image.png
标注关键点,检测关键点用Create Point

按关键点顺序标注,如我们的顺序是head、tail,不可以错;

填写关键点名称,如这里是head;

填写关键点所在物体的group_id,匹配检测框

image.png
注意,如果多个类型需要填写group_id,group_id要匹配检测框

4、把JSON转成TXT

使用labelme标注生成的JSON文件,不能直接训练,所以需要把JSON文件转成TXT文件。

打开一个标注json文件,内容大致如下

version": "5.3.1",
  "flags": {},
  "shapes": [
    {
      "label": "smoke",
      "points": [
        [
          389.0,
          72.5
        ],
        [
          957.0,
          114.5
        ],
        [
          949.0,
          192.5
        ],
        [
          379.0,
          162.5
        ]
      ],
      "group_id": null,
      "description": "",
      "shape_type": "polygon",
      "flags": {}
    },

但是yolov8 要求的标注文件长这样

0 0.662641 0.494385 0.674719 0.988771 0.717187 0.189583 2.000000 0.798438 0.127083 2.000000 0.701562 0.091667 2.000000 0.921875 0.118750 2.000000 0.000000 0.000000 0.000000 0.971875 0.379167 2.000000 0.554688 0.262500 2.000000 0.000000 0.000000 0.000000 0.367188 0.427083 2.000000 0.767188 0.772917 2.000000 0.421875 0.500000 2.000000 0.829688 0.960417 1.000000 0.517188 0.881250 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000

根据这些规则,我们可以写一个转换脚本,将labelme标注的json格式转为yolo格式

labelme2yolo.py


# 将labelme标注的json文件转为yolo格式
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import glob
import json
import tqdm
​
​
# 物体类别
class_list = ["smoke"]
# 关键点的顺序
keypoint_list = ["head", "tail"]
​
​
def json_to_yolo(img_data,json_data):
    h,w = img_data.shape[:2]
    # 步骤:
    # 1. 找出所有的矩形,记录下矩形的坐标,以及对应group_id
    # 2. 遍历所有的head和tail,记下点的坐标,以及对应group_id,加入到对应的矩形中
    # 3. 转为yolo格式
​
    rectangles = {}
    # 遍历初始化
    for shape in json_data["shapes"]:
        label = shape["label"] # pen, head, tail
        group_id = shape["group_id"] # 0, 1, 2, ...
        points = shape["points"] # x,y coordinates
        shape_type = shape["shape_type"]
​
        # 只处理矩形
        if shape_type == "rectangle":
            if group_id not in rectangles:
                rectangles[group_id] = {
                    "label": label,
                    "rect": points[0] + points[1],  # Rectangle [x1, y1, x2, y2]
                    "keypoints_list": []
                }
    # 遍历更新,将点加入对应group_id的矩形中
    for keypoint in keypoint_list:
        for shape in json_data["shapes"]:
            label = shape["label"]
            group_id = shape["group_id"]
            points = shape["points"]
            # 如果匹配到了对应的keypoint
            if label == keypoint:
                rectangles[group_id]["keypoints_list"].append(points[0])

    # 转为yolo格式
    yolo_list = []
    for id, rectangle in rectangles.items():
        result_list  = []
        label_id = class_list.index(rectangle["label"])
        # x1,y1,x2,y2
        x1,y1,x2,y2 = rectangle["rect"]
        # center_x, center_y, width, height
        center_x = (x1+x2)/2
        center_y = (y1+y2)/2
        width = abs(x1-x2)
        height = abs(y1-y2)
        # normalize
        center_x /= w
        center_y /= h
        width /= w
        height /= h
​
        # 保留6位小数
        center_x = round(center_x, 6)
        center_y = round(center_y, 6)
        width = round(width, 6)
        height = round(height, 6)
​
​
        # 添加 label_id, center_x, center_y, width, height
        result_list = [label_id, center_x, center_y, width, height]
​
        # 添加 p1_x, p1_y, p1_v, p2_x, p2_y, p2_v
        for point in rectangle["keypoints_list"]:
            x,y = point
            x,y = int(x), int(y)
            # normalize
            x /= w
            y /= h
            # 保留6位小数
            x = round(x, 6)
            y = round(y, 6)

            result_list.extend([x,y,2])
​
        yolo_list.append(result_list)

    return yolo_list
​
​
# 获取所有的图片
img_list = glob.glob("./*.jpg")
​
for img_path in tqdm.tqdm( img_list ):

    img = cv2.imread(img_path)
    print(img_path)
    json_file = img_path.replace('jpg', 'json')
    with open(json_file) as json_file:
        json_data = json.load(json_file)
​
    yolo_list = json_to_yolo(img, json_data)

    yolo_txt_path = img_path.replace('jpg', 'txt')
    with open(yolo_txt_path, "w") as f:
        for yolo in yolo_list:
            for i in range(len(yolo)):
                if i == 0:
                    f.write(str(yolo[i]))
                else:
                    f.write(" " + str(yolo[i]))
            f.write("\n")
​

运行后,把JSON生成TXT。
image.png

5、对yolo格式的标注进行可视化。检查标注是否正确。

先安装jupyter

pip install jupyter
然后在终端输入

jupyter-lab


import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import glob

img_path = './images/1.original_annotated/1 (20).jpg'

plt.figure(figsize=(15,10))
img = cv2.imread(img_path)
plt.imshow(img[:,:,::-1])
plt.axis('off')

yolo_txt_path = img_path.replace('jpg', 'txt')
print(yolo_txt_path)

with open(yolo_txt_path, 'r') as f:
    lines = f.readlines()

lines = [x.strip() for x in lines]

label = np.array([x.split() for x in lines], dtype=np.float32)

# 物体类别
class_list = ["smoke"]
# 类别的颜色
class_color = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255,255,0)]
# 关键点的顺序
keypoint_list = ["head", "tail"]
# 关键点的颜色
keypoint_color = [(255, 0, 0), (0, 255, 0)]

# 绘制检测框
img_copy = img.copy()
h,w = img_copy.shape[:2]
for id,l in enumerate( label ):
    # label_id ,center x,y and width, height
    label_id, cx, cy, bw, bh = l[0:5]
    label_text = class_list[int(label_id)]
    # rescale to image size
    cx *= w
    cy *= h
    bw *= w
    bh *= h

    # draw the bounding box
    xmin = int(cx - bw/2)
    ymin = int(cy - bh/2)
    xmax = int(cx + bw/2)
    ymax = int(cy + bh/2)
    cv2.rectangle(img_copy, (xmin, ymin), (xmax, ymax), class_color[int(label_id)], 2)
    cv2.putText(img_copy, label_text, (xmin, ymin-10), cv2.FONT_HERSHEY_SIMPLEX, 1, class_color[int(label_id)], 2)


# display the image
plt.figure(figsize=(15,10))
plt.imshow(img_copy[:,:,::-1])
plt.axis('off')
# save the image
cv2.imwrite("./tmp.jpg", img_copy)

img_copy = img.copy()
h,w = img_copy.shape[:2]
for id,l in enumerate( label ):
    # label_id ,center x,y and width, height
    label_id, cx, cy, bw, bh = l[0:5]
    label_text = class_list[int(label_id)]
    # rescale to image size
    cx *= w
    cy *= h
    bw *= w
    bh *= h

    # draw the bounding box
    xmin = int(cx - bw/2)
    ymin = int(cy - bh/2)
    xmax = int(cx + bw/2)
    ymax = int(cy + bh/2)
    cv2.rectangle(img_copy, (xmin, ymin), (xmax, ymax), class_color[int(label_id)], 2)
    cv2.putText(img_copy, label_text, (xmin, ymin-10), cv2.FONT_HERSHEY_SIMPLEX, 2, class_color[int(label_id)], 2)

    # draw 17 keypoints, px,py,pv,px,py,pv...
    for i in range(5, len(l), 3):
        px, py, pv = l[i:i+3]
        # rescale to image size
        px *= w
        py *= h
        # puttext the index 
        index = int((i-5)/3)
        # draw the keypoints
        cv2.circle(img_copy, (int(px), int(py)), 10, keypoint_color[int(index)], -1)
        keypoint_text = "{}_{}".format(index, keypoint_list[index])
        cv2.putText(img_copy, keypoint_text, (int(px), int(py)-10), cv2.FONT_HERSHEY_SIMPLEX, 1, keypoint_color[int(index)], 2)

plt.figure(figsize=(15,10))
plt.imshow(img_copy[:,:,::-1])
plt.axis('off')
# save 
cv2.imwrite('./tmp.jpg', img_copy)

image.png

6、训练

组织目录结构,我是直接把datasets目录放到yolov8的根目录下:
image.png
图片 datasets/custom_dataset/images/train/{文件名}.jpg对应的标注文件在 datasets/custom_dataset/labels/train/{文件名}.txt,YOLO会根据这个映射关系自动寻找(images换成labels);
训练集和验证集
images文件夹下有train和val文件夹,分别放置训练集和验证集图片;
labels文件夹有train和val文件夹,分别放置训练集和验证集标签(yolo格式)。
配置文件smoke.yaml

# Ultralytics YOLO 🚀, AGPL-3.0 license
# COCO8-pose dataset (first 8 images from COCO train2017) by Ultralytics
# Example usage: yolo train data=coco8-pose.yaml
# parent
# ├── ultralytics
# └── datasets
#     └── coco8-pose  ← downloads here (1 MB)


# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: ../datasets/custom_dataset  # dataset root dir
train: images/train  # train images (relative to 'path') 4 images
val: images/val  # val images (relative to 'path') 4 images
test:  # test images (optional)

# Keypoints
kpt_shape: [2, 3]  # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
#flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]

# Classes
names:
  0: smoke

# Download script/URL (optional)
#download: https://ultralytics.com/assets/coco8-pose.zip

編寫訓練train.py,訓練自己的模型

from ultralytics import YOLO

# 加载模型
# model = YOLO('yolov8s-pose.yaml')  # 从头训练
model = YOLO('./yolov8s-pose.pt')  # 使用预训练模型 (recommended for training)
# model = YOLO('yolov8s-pose.yaml').load('yolov8s-pose.pt')  # 从yaml构建网络并从预训练模型中迁移权重

# 训练
results = model.train(data='./smoke_pose.yaml', epochs=300, imgsz=640, workers=0, batch=20, project="pen_bolt", name="s120")

执行python train.py
image.png
会看到开始训练,并在pen_bolt\s1202下看到训练的结果。
image.png

7、测试

由于数据集才20张且只训练了120轮,效果不是太好,所以这里就直接使用图片测试

# 测试图片
from ultralytics import YOLO
import cv2
import numpy as np
import sys
# 读取命令行参数
weight_path = "./pen_bolt/s1202/weights/best.pt"
media_path = "./test.jpg"

# 加载模型
model = YOLO(weight_path )

# 获取类别
objs_labels = model.names  # get class labels
print(objs_labels)


# 类别的颜色
class_color = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255,255,0)]
# 关键点的顺序
keypoint_list = ["head", "tail"]
# 关键点的颜色
keypoint_color = [(255, 0, 0), (0, 255, 0)]

# 读取图片
frame = cv2.imread(media_path)
frame = cv2.resize(frame, (frame.shape[1]//2, frame.shape[0]//2))
# rotate
# 检测
result = list(model(frame, conf=0.3, stream=True))[0]  # inference,如果stream=False,返回的是一个列表,如果stream=True,返回的是一个生成器
boxes = result.boxes  # Boxes object for bbox outputs
boxes = boxes.cpu().numpy()  # convert to numpy array

# 遍历每个框
for box in boxes.data:
    l,t,r,b = box[:4].astype(np.int32) # left, top, right, bottom
    conf, id = box[4:] # confidence, class
    id = int(id)
    # 绘制框
    cv2.rectangle(frame, (l,t), (r,b), class_color[id], 2)
    # 绘制类别+置信度(格式:98.1%)
    cv2.putText(frame, f"{objs_labels[id]} {conf*100:.1f}%", (l, t-10), cv2.FONT_HERSHEY_SIMPLEX, 1, class_color[id], 2)

# 遍历keypoints
keypoints = result.keypoints  # Keypoints object for pose outputs
keypoints = keypoints.cpu().numpy()  # convert to numpy array

# draw keypoints, set first keypoint is red, second is blue
for keypoint in keypoints.data:
    for i in range(len(keypoint)):
        x,y,c = keypoint[i]
        x,y = int(x), int(y)
        cv2.circle(frame, (x,y), 10, keypoint_color[i], -1)
        cv2.putText(frame, f"{keypoint_list[i]}", (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 1, keypoint_color[i], 2)

    if len(keypoint) >= 2:
        # draw arrow line from tail to half between head and tail
        x1,y1,c1 = keypoint[0]
        x2,y2,c2 = keypoint[1]
        center_x, center_y = (x1+x2)/2, (y1+y2)/2
        cv2.arrowedLine(frame, (int(x2),int(y2)), (int(center_x), int(center_y)), (255,0,255), 4, line_type=cv2.LINE_AA, tipLength=0.1)


# save image
cv2.imwrite("result.jpg", frame)
print("save result.jpg")

直接python test.py测试,测试同一方向还好,旋转方向后有点不对,怀疑是数据集和参数问题,但整个过程是OK的。

image.png

相关文章
|
3天前
|
Python 机器学习/深度学习 人工智能
手把手教你从零开始构建并训练你的第一个强化学习智能体:深入浅出Agent项目实战,带你体验编程与AI结合的乐趣
【10月更文挑战第1天】本文通过构建一个简单的强化学习环境,演示了如何创建和训练智能体以完成特定任务。我们使用Python、OpenAI Gym和PyTorch搭建了一个基础的智能体,使其学会在CartPole-v1环境中保持杆子不倒。文中详细介绍了环境设置、神经网络构建及训练过程。此实战案例有助于理解智能体的工作原理及基本训练方法,为更复杂应用奠定基础。首先需安装必要库: ```bash pip install gym torch ``` 接着定义环境并与之交互,实现智能体的训练。通过多个回合的试错学习,智能体逐步优化其策略。这一过程虽从基础做起,但为后续研究提供了良好起点。
16 4
手把手教你从零开始构建并训练你的第一个强化学习智能体:深入浅出Agent项目实战,带你体验编程与AI结合的乐趣
|
1月前
|
机器学习/深度学习 人工智能 PyTorch
AI计算机视觉笔记三十二:LPRNet车牌识别
LPRNet是一种基于Pytorch的高性能、轻量级车牌识别框架,适用于中国及其他国家的车牌识别。该网络无需对字符进行预分割,采用端到端的轻量化设计,结合了squeezenet和inception的思想。其创新点在于去除了RNN,仅使用CNN与CTC Loss,并通过特定的卷积模块提取上下文信息。环境配置包括使用CPU开发板和Autodl训练环境。训练和测试过程需搭建虚拟环境并安装相关依赖,执行训练和测试脚本时可能遇到若干错误,需相应调整代码以确保正确运行。使用官方模型可获得较高的识别准确率,自行训练时建议增加训练轮数以提升效果。
|
1月前
|
人工智能 开发工具 计算机视觉
AI计算机视觉笔记三十:yolov8_obb旋转框训练
本文介绍了如何使用AUTODL环境搭建YOLOv8-obb的训练流程。首先创建虚拟环境并激活,然后通过指定清华源安装ultralytics库。接着下载YOLOv8源码,并使用指定命令开始训练,过程中可能会下载yolov8n.pt文件。训练完成后,可使用相应命令进行预测测试。
|
1月前
|
人工智能 并行计算 测试技术
AI计算机视觉笔记三十一:基于UNetMultiLane的多车道线等识别
该项目基于开源数据集 VIL100 实现了 UNetMultiLane,用于多车道线及车道线类型的识别。数据集中标注了六个车道的车道线及其类型。项目详细记录了从环境搭建到模型训练与测试的全过程,并提供了在 CPU 上进行训练和 ONNX 转换的代码示例。训练过程约需 4 小时完成 50 个 epoch。此外,还实现了视频检测功能,可在视频中实时识别车道线及其类型。
|
4天前
|
机器学习/深度学习 人工智能 自然语言处理
【AI系统】AI在不同领域的应用与行业影响
本文探讨了人工智能在计算机视觉、自然语言处理及音频处理等领域的广泛应用,并展示了其在自动驾驶、安全监控、搜索引擎优化、客户服务、语音识别及多个行业的革新作用,强调了AI基础设施与系统创新对未来社会的影响与价值。
27 1
|
5天前
|
人工智能 Serverless
AI助理精准匹配,为您推荐方案——如何添加一个Stable Difussion图像生成应用
介绍了一种利用AI助手快速获取并搭建Stable Diffusion图像生成应用的方法。用户只需在阿里云官网向AI助手提出需求,即可获得详细的实施方案。随后,按照AI助手提供的方案,通过函数计算部署应用,并进行测试。此过程显著提升了开发效率。
38 1
AI助理精准匹配,为您推荐方案——如何添加一个Stable Difussion图像生成应用
|
2天前
|
机器学习/深度学习 人工智能 算法
探索AI在医疗诊断中的应用与挑战
【10月更文挑战第2天】本文深入探讨了人工智能技术在医疗诊断领域的应用,以及其带来的变革。通过分析AI技术的工作原理和实际应用案例,揭示了AI在提高诊断准确率、优化治疗流程等方面的巨大潜力。同时,文章也指出了AI在医疗领域面临的伦理、法律和技术等挑战,并讨论了未来可能的发展方向。
16 7
|
1天前
|
机器学习/深度学习 人工智能 自然语言处理
AI技术在自然语言处理中的应用与挑战
【10月更文挑战第3天】本文将探讨AI技术在自然语言处理(NLP)领域的应用及其面临的挑战。我们将分析NLP的基本原理,介绍AI技术如何推动NLP的发展,并讨论当前的挑战和未来的趋势。通过本文,读者将了解AI技术在NLP中的重要性,以及如何利用这些技术解决实际问题。
|
1天前
|
人工智能 缓存 Java
深入解析Spring AI框架:在Java应用中实现智能化交互的关键
【10月更文挑战第12天】Spring AI 是 Spring 框架家族的新成员,旨在满足 Java 应用程序对人工智能集成的需求。它支持自然语言处理、图像识别等多种 AI 技术,并提供与云服务(如 OpenAI、Azure Cognitive Services)及本地模型的无缝集成。通过简单的配置和编码,开发者可轻松实现 AI 功能,同时应对模型切换、数据安全及性能优化等挑战。
|
2天前
|
机器学习/深度学习 人工智能 TensorFlow
AI技术在医疗诊断中的应用与挑战
【10月更文挑战第2天】本文将探讨AI技术在医疗诊断领域的应用及其带来的挑战。我们将通过实际案例和代码示例,展示AI如何帮助医生更准确地诊断疾病,并讨论其面临的伦理和法律问题。
10 4

热门文章

最新文章