AI计算机视觉笔记二十六:YOLOV8自训练关键点检测

简介: 本文档详细记录了使用YOLOv8训练关键点检测模型的过程。首先通过清华源安装YOLOv8,并验证安装。接着通过示例权重文件与测试图片`bus.jpg`演示预测流程。为准备训练数据,文档介绍了如何使用`labelme`标注工具进行关键点标注,并提供了一个Python脚本`labelme2yolo.py`将标注结果从JSON格式转换为YOLO所需的TXT格式。随后,通过Jupyter Notebook可视化标注结果确保准确性。最后,文档展示了如何组织数据集目录结构,并提供了训练与测试代码示例,包括配置文件`smoke.yaml`及训练脚本`train.py`,帮助读者完成自定义模型的训练与评估。

记录学习YOLOV8过程,自训练关键点检测模型。

清华源:-i https://mirror.baidu.com/pypi/simple

1、yolov8安装

git clone https://github.com/ultralytics/ultralytics
​
cd ultralytics
​
pip install -e .
AI 代码解读

安装成功后,使用命令 yolo 简单看下版本

(yolov8) llh@anhao:/$ yolo version
​
8.0.206
AI 代码解读

2、简单测试

下载权重文件

GitHub - ultralytics/ultralytics: NEW - YOLOv8 🚀 in PyTorch > ONNX > OpenVINO > CoreML > TFLite
image.png
直接点击下载。

获取测试图片的文件在ultralytics\assets目录有,使用的是的里面的bus.jpg测试。

使用 yolo 命令进行测试

yolo detect predict model=./yolov8n.pt source=./bus.jpg
AI 代码解读

image.png
输出在runs/detect/predict/目录下。
image.png

3、安装labelme

pip install labelme
AI 代码解读

直接在终端运行labelme打开软件
image.png
先取消“保存图片数据”(减少标注文件大小);在文件下
image.png
打开文件目录

image.png
接下来标注

先标注检测框,检测框用Create Rectangle(Ctrl+N)

填写类别名称

填写group_id,用于匹配后续标注的关键点,以当前画面中出现的顺序标注即可。
image.png
标注关键点,检测关键点用Create Point

按关键点顺序标注,如我们的顺序是head、tail,不可以错;

填写关键点名称,如这里是head;

填写关键点所在物体的group_id,匹配检测框

image.png
注意,如果多个类型需要填写group_id,group_id要匹配检测框

4、把JSON转成TXT

使用labelme标注生成的JSON文件,不能直接训练,所以需要把JSON文件转成TXT文件。

打开一个标注json文件,内容大致如下

version": "5.3.1",
  "flags": {},
  "shapes": [
    {
      "label": "smoke",
      "points": [
        [
          389.0,
          72.5
        ],
        [
          957.0,
          114.5
        ],
        [
          949.0,
          192.5
        ],
        [
          379.0,
          162.5
        ]
      ],
      "group_id": null,
      "description": "",
      "shape_type": "polygon",
      "flags": {}
    },
AI 代码解读

但是yolov8 要求的标注文件长这样

0 0.662641 0.494385 0.674719 0.988771 0.717187 0.189583 2.000000 0.798438 0.127083 2.000000 0.701562 0.091667 2.000000 0.921875 0.118750 2.000000 0.000000 0.000000 0.000000 0.971875 0.379167 2.000000 0.554688 0.262500 2.000000 0.000000 0.000000 0.000000 0.367188 0.427083 2.000000 0.767188 0.772917 2.000000 0.421875 0.500000 2.000000 0.829688 0.960417 1.000000 0.517188 0.881250 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000

根据这些规则,我们可以写一个转换脚本,将labelme标注的json格式转为yolo格式

labelme2yolo.py


# 将labelme标注的json文件转为yolo格式
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import glob
import json
import tqdm
​
​
# 物体类别
class_list = ["smoke"]
# 关键点的顺序
keypoint_list = ["head", "tail"]
​
​
def json_to_yolo(img_data,json_data):
    h,w = img_data.shape[:2]
    # 步骤:
    # 1. 找出所有的矩形,记录下矩形的坐标,以及对应group_id
    # 2. 遍历所有的head和tail,记下点的坐标,以及对应group_id,加入到对应的矩形中
    # 3. 转为yolo格式
​
    rectangles = {}
    # 遍历初始化
    for shape in json_data["shapes"]:
        label = shape["label"] # pen, head, tail
        group_id = shape["group_id"] # 0, 1, 2, ...
        points = shape["points"] # x,y coordinates
        shape_type = shape["shape_type"]
​
        # 只处理矩形
        if shape_type == "rectangle":
            if group_id not in rectangles:
                rectangles[group_id] = {
                    "label": label,
                    "rect": points[0] + points[1],  # Rectangle [x1, y1, x2, y2]
                    "keypoints_list": []
                }
    # 遍历更新,将点加入对应group_id的矩形中
    for keypoint in keypoint_list:
        for shape in json_data["shapes"]:
            label = shape["label"]
            group_id = shape["group_id"]
            points = shape["points"]
            # 如果匹配到了对应的keypoint
            if label == keypoint:
                rectangles[group_id]["keypoints_list"].append(points[0])

    # 转为yolo格式
    yolo_list = []
    for id, rectangle in rectangles.items():
        result_list  = []
        label_id = class_list.index(rectangle["label"])
        # x1,y1,x2,y2
        x1,y1,x2,y2 = rectangle["rect"]
        # center_x, center_y, width, height
        center_x = (x1+x2)/2
        center_y = (y1+y2)/2
        width = abs(x1-x2)
        height = abs(y1-y2)
        # normalize
        center_x /= w
        center_y /= h
        width /= w
        height /= h
​
        # 保留6位小数
        center_x = round(center_x, 6)
        center_y = round(center_y, 6)
        width = round(width, 6)
        height = round(height, 6)
​
​
        # 添加 label_id, center_x, center_y, width, height
        result_list = [label_id, center_x, center_y, width, height]
​
        # 添加 p1_x, p1_y, p1_v, p2_x, p2_y, p2_v
        for point in rectangle["keypoints_list"]:
            x,y = point
            x,y = int(x), int(y)
            # normalize
            x /= w
            y /= h
            # 保留6位小数
            x = round(x, 6)
            y = round(y, 6)

            result_list.extend([x,y,2])
​
        yolo_list.append(result_list)

    return yolo_list
​
​
# 获取所有的图片
img_list = glob.glob("./*.jpg")
​
for img_path in tqdm.tqdm( img_list ):

    img = cv2.imread(img_path)
    print(img_path)
    json_file = img_path.replace('jpg', 'json')
    with open(json_file) as json_file:
        json_data = json.load(json_file)
​
    yolo_list = json_to_yolo(img, json_data)

    yolo_txt_path = img_path.replace('jpg', 'txt')
    with open(yolo_txt_path, "w") as f:
        for yolo in yolo_list:
            for i in range(len(yolo)):
                if i == 0:
                    f.write(str(yolo[i]))
                else:
                    f.write(" " + str(yolo[i]))
            f.write("\n")
​
AI 代码解读

运行后,把JSON生成TXT。
image.png

5、对yolo格式的标注进行可视化。检查标注是否正确。

先安装jupyter

pip install jupyter
然后在终端输入

jupyter-lab


import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import glob

img_path = './images/1.original_annotated/1 (20).jpg'

plt.figure(figsize=(15,10))
img = cv2.imread(img_path)
plt.imshow(img[:,:,::-1])
plt.axis('off')

yolo_txt_path = img_path.replace('jpg', 'txt')
print(yolo_txt_path)

with open(yolo_txt_path, 'r') as f:
    lines = f.readlines()

lines = [x.strip() for x in lines]

label = np.array([x.split() for x in lines], dtype=np.float32)

# 物体类别
class_list = ["smoke"]
# 类别的颜色
class_color = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255,255,0)]
# 关键点的顺序
keypoint_list = ["head", "tail"]
# 关键点的颜色
keypoint_color = [(255, 0, 0), (0, 255, 0)]

# 绘制检测框
img_copy = img.copy()
h,w = img_copy.shape[:2]
for id,l in enumerate( label ):
    # label_id ,center x,y and width, height
    label_id, cx, cy, bw, bh = l[0:5]
    label_text = class_list[int(label_id)]
    # rescale to image size
    cx *= w
    cy *= h
    bw *= w
    bh *= h

    # draw the bounding box
    xmin = int(cx - bw/2)
    ymin = int(cy - bh/2)
    xmax = int(cx + bw/2)
    ymax = int(cy + bh/2)
    cv2.rectangle(img_copy, (xmin, ymin), (xmax, ymax), class_color[int(label_id)], 2)
    cv2.putText(img_copy, label_text, (xmin, ymin-10), cv2.FONT_HERSHEY_SIMPLEX, 1, class_color[int(label_id)], 2)


# display the image
plt.figure(figsize=(15,10))
plt.imshow(img_copy[:,:,::-1])
plt.axis('off')
# save the image
cv2.imwrite("./tmp.jpg", img_copy)

img_copy = img.copy()
h,w = img_copy.shape[:2]
for id,l in enumerate( label ):
    # label_id ,center x,y and width, height
    label_id, cx, cy, bw, bh = l[0:5]
    label_text = class_list[int(label_id)]
    # rescale to image size
    cx *= w
    cy *= h
    bw *= w
    bh *= h

    # draw the bounding box
    xmin = int(cx - bw/2)
    ymin = int(cy - bh/2)
    xmax = int(cx + bw/2)
    ymax = int(cy + bh/2)
    cv2.rectangle(img_copy, (xmin, ymin), (xmax, ymax), class_color[int(label_id)], 2)
    cv2.putText(img_copy, label_text, (xmin, ymin-10), cv2.FONT_HERSHEY_SIMPLEX, 2, class_color[int(label_id)], 2)

    # draw 17 keypoints, px,py,pv,px,py,pv...
    for i in range(5, len(l), 3):
        px, py, pv = l[i:i+3]
        # rescale to image size
        px *= w
        py *= h
        # puttext the index 
        index = int((i-5)/3)
        # draw the keypoints
        cv2.circle(img_copy, (int(px), int(py)), 10, keypoint_color[int(index)], -1)
        keypoint_text = "{}_{}".format(index, keypoint_list[index])
        cv2.putText(img_copy, keypoint_text, (int(px), int(py)-10), cv2.FONT_HERSHEY_SIMPLEX, 1, keypoint_color[int(index)], 2)

plt.figure(figsize=(15,10))
plt.imshow(img_copy[:,:,::-1])
plt.axis('off')
# save 
cv2.imwrite('./tmp.jpg', img_copy)
AI 代码解读

image.png

6、训练

组织目录结构,我是直接把datasets目录放到yolov8的根目录下:
image.png
图片 datasets/custom_dataset/images/train/{文件名}.jpg对应的标注文件在 datasets/custom_dataset/labels/train/{文件名}.txt,YOLO会根据这个映射关系自动寻找(images换成labels);
训练集和验证集
images文件夹下有train和val文件夹,分别放置训练集和验证集图片;
labels文件夹有train和val文件夹,分别放置训练集和验证集标签(yolo格式)。
配置文件smoke.yaml

# Ultralytics YOLO 🚀, AGPL-3.0 license
# COCO8-pose dataset (first 8 images from COCO train2017) by Ultralytics
# Example usage: yolo train data=coco8-pose.yaml
# parent
# ├── ultralytics
# └── datasets
#     └── coco8-pose  ← downloads here (1 MB)


# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: ../datasets/custom_dataset  # dataset root dir
train: images/train  # train images (relative to 'path') 4 images
val: images/val  # val images (relative to 'path') 4 images
test:  # test images (optional)

# Keypoints
kpt_shape: [2, 3]  # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
#flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]

# Classes
names:
  0: smoke

# Download script/URL (optional)
#download: https://ultralytics.com/assets/coco8-pose.zip
AI 代码解读

編寫訓練train.py,訓練自己的模型

from ultralytics import YOLO

# 加载模型
# model = YOLO('yolov8s-pose.yaml')  # 从头训练
model = YOLO('./yolov8s-pose.pt')  # 使用预训练模型 (recommended for training)
# model = YOLO('yolov8s-pose.yaml').load('yolov8s-pose.pt')  # 从yaml构建网络并从预训练模型中迁移权重

# 训练
results = model.train(data='./smoke_pose.yaml', epochs=300, imgsz=640, workers=0, batch=20, project="pen_bolt", name="s120")
AI 代码解读

执行python train.py
image.png
会看到开始训练,并在pen_bolt\s1202下看到训练的结果。
image.png

7、测试

由于数据集才20张且只训练了120轮,效果不是太好,所以这里就直接使用图片测试

# 测试图片
from ultralytics import YOLO
import cv2
import numpy as np
import sys
# 读取命令行参数
weight_path = "./pen_bolt/s1202/weights/best.pt"
media_path = "./test.jpg"

# 加载模型
model = YOLO(weight_path )

# 获取类别
objs_labels = model.names  # get class labels
print(objs_labels)


# 类别的颜色
class_color = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255,255,0)]
# 关键点的顺序
keypoint_list = ["head", "tail"]
# 关键点的颜色
keypoint_color = [(255, 0, 0), (0, 255, 0)]

# 读取图片
frame = cv2.imread(media_path)
frame = cv2.resize(frame, (frame.shape[1]//2, frame.shape[0]//2))
# rotate
# 检测
result = list(model(frame, conf=0.3, stream=True))[0]  # inference,如果stream=False,返回的是一个列表,如果stream=True,返回的是一个生成器
boxes = result.boxes  # Boxes object for bbox outputs
boxes = boxes.cpu().numpy()  # convert to numpy array

# 遍历每个框
for box in boxes.data:
    l,t,r,b = box[:4].astype(np.int32) # left, top, right, bottom
    conf, id = box[4:] # confidence, class
    id = int(id)
    # 绘制框
    cv2.rectangle(frame, (l,t), (r,b), class_color[id], 2)
    # 绘制类别+置信度(格式:98.1%)
    cv2.putText(frame, f"{objs_labels[id]} {conf*100:.1f}%", (l, t-10), cv2.FONT_HERSHEY_SIMPLEX, 1, class_color[id], 2)

# 遍历keypoints
keypoints = result.keypoints  # Keypoints object for pose outputs
keypoints = keypoints.cpu().numpy()  # convert to numpy array

# draw keypoints, set first keypoint is red, second is blue
for keypoint in keypoints.data:
    for i in range(len(keypoint)):
        x,y,c = keypoint[i]
        x,y = int(x), int(y)
        cv2.circle(frame, (x,y), 10, keypoint_color[i], -1)
        cv2.putText(frame, f"{keypoint_list[i]}", (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 1, keypoint_color[i], 2)

    if len(keypoint) >= 2:
        # draw arrow line from tail to half between head and tail
        x1,y1,c1 = keypoint[0]
        x2,y2,c2 = keypoint[1]
        center_x, center_y = (x1+x2)/2, (y1+y2)/2
        cv2.arrowedLine(frame, (int(x2),int(y2)), (int(center_x), int(center_y)), (255,0,255), 4, line_type=cv2.LINE_AA, tipLength=0.1)


# save image
cv2.imwrite("result.jpg", frame)
print("save result.jpg")
AI 代码解读

直接python test.py测试,测试同一方向还好,旋转方向后有点不对,怀疑是数据集和参数问题,但整个过程是OK的。

image.png

目录
打赏
0
2
2
2
46
分享
相关文章
【一步步开发AI运动APP】七、自定义姿态动作识别检测——之规则配置检测
本文介绍了如何通过【一步步开发AI运动APP】系列博文,利用自定义姿态识别检测技术开发高性能的AI运动应用。核心内容包括:1) 自定义姿态识别检测,满足人像入镜、动作开始/停止等需求;2) Pose-Calc引擎详解,支持角度匹配、逻辑运算等多种人体分析规则;3) 姿态检测规则编写与执行方法;4) 完整示例展示左右手平举姿态检测。通过这些技术,开发者可轻松实现定制化运动分析功能。
AI驱动的幼儿跌倒检测——视频安全系统的技术解析
幼儿跌倒检测系统基于AI视频技术,融合人体姿态识别与实时报警功能,为幼儿园安全管理提供智能化解决方案。系统通过YOLOv9、OpenPose等算法实现高精度跌倒检测(准确率达98%),结合LSTM时间序列分析减少误报,支持目标分类区分幼儿与成人,并具备事件存储、实时通知及开源部署优势。其高效、灵活、隐私合规的特点显著提升安全管理效率,助力优化园所运营。
66 0
AI驱动的幼儿跌倒检测——视频安全系统的技术解析
用Qwen3+MCPs实现AI自动发布小红书笔记!支持图文和视频
魔搭自动发布小红书MCP,是魔搭开发者小伙伴实现的小红书笔记自动发布器,可以通过这个MCP自动完成小红书标题、内容和图片的发布。
438 39
“服务器老被黑?那是你没上AI哨兵!”——聊聊基于AI的网络攻击检测那些事儿
“服务器老被黑?那是你没上AI哨兵!”——聊聊基于AI的网络攻击检测那些事儿
80 12
Windows版来啦!Qwen3+MCPs,用AI自动发布小红书图文/视频笔记!
上一篇用 Qwen3+MCPs实现AI自动发小红书的最佳实践 有超多小伙伴关注,同时也排队在蹲Windows版本的教程。
169 1
FinGPT:华尔街颤抖!用股价训练AI,开源金融大模型预测股价准确率碾压分析师,量化交易新利器
FinGPT是基于Transformer架构的开源金融大模型,通过RLHF技术和实时数据处理能力,支持情感分析、市场预测等核心功能,其LoRA微调技术大幅降低训练成本。
145 12
FinGPT:华尔街颤抖!用股价训练AI,开源金融大模型预测股价准确率碾压分析师,量化交易新利器
AI“捕风捉影”:深度学习如何让网络事件检测更智能?
AI“捕风捉影”:深度学习如何让网络事件检测更智能?
53 8
【一步步开发AI运动APP】八、自定义姿态动作识别检测——之姿态相似度比较
本文介绍了如何通过姿态相似度比较技术简化AI运动应用开发。相比手动配置规则,插件`pose-calc`提供的姿态相似度比较器可快速评估两组人体关键点的整体与局部相似度,降低开发者工作量。文章还展示了在`uni-app`框架下调用姿态比较器的示例代码,并提供了桌面辅助工具以帮助提取标准动作样本,助力开发者打造性能更优、体验更好的AI运动APP。
17.1K star!两小时就能训练出专属与自己的个性化小模型,这个开源项目让AI触手可及!
🔥「只需一张消费级显卡,2小时完成26M参数GPT训练!」 🌟「从零构建中文大模型的最佳实践指南」 🚀「兼容OpenAI API,轻松接入各类AI应用平台」
小鹏汽车选用阿里云PolarDB,开启AI大模型训练新时代
PolarDB-PG云原生分布式数据库不仅提供了无限的扩展能力,还借助丰富的PostgreSQL生态系统,统一了后台技术栈,极大地简化了运维工作。这种强大的组合不仅提高了系统的稳定性和性能,还为小鹏汽车大模型训练的数据管理带来了前所未有的灵活性和效率。

热门文章

最新文章

AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等