yolo-world 源码解析(四)(2)

本文涉及的产品
云解析 DNS,旗舰版 1个月
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: yolo-world 源码解析(四)

yolo-world 源码解析(四)(1)https://developer.aliyun.com/article/1483875

.\YOLO-World\image_demo.py

# 版权声明
# 导入必要的库
import os
import cv2
import argparse
import os.path as osp
import torch
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
from mmengine.runner.amp import autocast
from mmengine.dataset import Compose
from mmengine.utils import ProgressBar
from mmyolo.registry import RUNNERS
# 定义BOUNDING_BOX_ANNOTATOR对象
BOUNDING_BOX_ANNOTATOR = None
# 定义LABEL_ANNOTATOR对象
LABEL_ANNOTATOR = None
# 解析命令行参数
def parse_args():
    parser = argparse.ArgumentParser(description='YOLO-World Demo')
    # 添加命令行参数
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('image', help='image path, include image file or dir.')
    parser.add_argument(
        'text',
        help='text prompts, including categories separated by a comma or a txt file with each line as a prompt.'
    )
    parser.add_argument('--topk',
                        default=100,
                        type=int,
                        help='keep topk predictions.')
    parser.add_argument('--threshold',
                        default=0.0,
                        type=float,
                        help='confidence score threshold for predictions.')
    parser.add_argument('--device',
                        default='cuda:0',
                        help='device used for inference.')
    parser.add_argument('--show',
                        action='store_true',
                        help='show the detection results.')
    parser.add_argument('--annotation',
                        action='store_true',
                        help='save the annotated detection results as yolo text format.')
    parser.add_argument('--amp',
                        action='store_true',
                        help='use mixed precision for inference.')
    # 添加一个名为'--output-dir'的命令行参数,用于指定保存输出的目录,默认为'demo_outputs'
    parser.add_argument('--output-dir',
                        default='demo_outputs',
                        help='the directory to save outputs')
    # 添加一个名为'--cfg-options'的命令行参数,用于覆盖配置文件中的一些设置,支持键值对形式的参数
    # 如果要覆盖的值是列表,则应该以 key="[a,b]" 或 key=a,b 的格式提供
    # 还支持嵌套列表/元组值,例如 key="[(a,b),(c,d)]"
    # 注意引号是必要的,不允许有空格
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    # 解析命令行参数
    args = parser.parse_args()
    # 返回解析后的参数
    return args
# 推断检测器,运行模型进行推断
def inference_detector(runner,
                       image_path,
                       texts,
                       max_dets,
                       score_thr,
                       output_dir,
                       use_amp=False,
                       show=False,
                       annotation=False):
    # 创建包含图像信息的字典
    data_info = dict(img_id=0, img_path=image_path, texts=texts)
    # 运行数据处理管道
    data_info = runner.pipeline(data_info)
    # 创建包含数据批次信息的字典
    data_batch = dict(inputs=data_info['inputs'].unsqueeze(0),
                      data_samples=[data_info['data_samples']])
    # 使用自动混合精度和禁用梯度计算
    with autocast(enabled=use_amp), torch.no_grad():
        # 运行模型的测试步骤
        output = runner.model.test_step(data_batch)[0]
        pred_instances = output.pred_instances
        # 通过设置阈值过滤预测实例
        pred_instances = pred_instances[
            pred_instances.scores.float() > score_thr]
    # 如果预测实例数量超过最大检测数
    if len(pred_instances.scores) > max_dets:
        # 选择得分最高的前 max_dets 个预测实例
        indices = pred_instances.scores.float().topk(max_dets)[1]
        pred_instances = pred_instances[indices]
    # 将预测实例转换为 numpy 数组
    pred_instances = pred_instances.cpu().numpy()
    # 定义检测对象
    detections = None
    # 为每个检测结果添加标签
    labels = [
        f"{texts[class_id][0]} {confidence:0.2f}" for class_id, confidence in
        zip(detections.class_id, detections.confidence)
    ]
    # 读取图像
    image = cv2.imread(image_path)
    anno_image = image.copy()
    # 在图像上绘制边界框
    image = BOUNDING_BOX_ANNOTATOR.annotate(image, detections)
    # 在图像上添加标签
    image = LABEL_ANNOTATOR.annotate(image, detections, labels=labels)
    # 将标记后的图像保存到输出目录
    cv2.imwrite(osp.join(output_dir, osp.basename(image_path)), image)
    # 如果有注释
    if annotation:
        # 创建空字典用于存储图像和注释
        images_dict = {}
        annotations_dict = {}
        # 将图像路径的基本名称作为键,注释图像作为值存储在图像字典中
        images_dict[osp.basename(image_path)] = anno_image
        # 将图像路径的基本名称作为键,检测结果作为值存储在注释字典中
        annotations_dict[osp.basename(image_path)] = detections
        
        # 创建一个名为ANNOTATIONS_DIRECTORY的目录,如果目录已存在则不创建
        ANNOTATIONS_DIRECTORY =  os.makedirs(r"./annotations", exist_ok=True)
        # 设置最小图像面积百分比
        MIN_IMAGE_AREA_PERCENTAGE = 0.002
        # 设置最大图像面积百分比
        MAX_IMAGE_AREA_PERCENTAGE = 0.80
        # 设置近似百分比
        APPROXIMATION_PERCENTAGE = 0.75
        
        # 创建一个DetectionDataset对象,传入类别、图像字典和注释字典,然后转换为YOLO格式
        sv.DetectionDataset(
            classes=texts,
            images=images_dict,
            annotations=annotations_dict
        ).as_yolo(
            annotations_directory_path=ANNOTATIONS_DIRECTORY,
            min_image_area_percentage=MIN_IMAGE_AREA_PERCENTAGE,
            max_image_area_percentage=MAX_IMAGE_AREA_PERCENTAGE,
            approximation_percentage=APPROXIMATION_PERCENTAGE
        )
    # 如果需要展示图像
    if show:
        # 在窗口中展示图像,提供窗口名称
        cv2.imshow('Image', image)
        # 等待按键输入,0表示一直等待
        k = cv2.waitKey(0)
        # 如果按下ESC键(ASCII码为27),关闭所有窗口
        if k == 27:
            cv2.destroyAllWindows()
if __name__ == '__main__':
    # 解析命令行参数
    args = parse_args()
    # 加载配置文件
    cfg = Config.fromfile(args.config)
    # 如果有额外的配置选项,则合并到配置文件中
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # 设置工作目录为当前目录下的 work_dirs 文件夹中,使用配置文件名作为子目录名
    cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
    # 加载模型检查点
    cfg.load_from = args.checkpoint
    # 根据配置文件中是否包含 runner_type 字段来选择不同的 Runner 类型
    if 'runner_type' not in cfg:
        runner = Runner.from_cfg(cfg)
    else:
        runner = RUNNERS.build(cfg)
    # 加载文本数据
    if args.text.endswith('.txt'):
        with open(args.text) as f:
            lines = f.readlines()
        # 将文本数据转换为列表形式
        texts = [[t.rstrip('\r\n')] for t in lines] + [[' ']]
    else:
        # 将命令行参数中的文本数据转换为列表形式
        texts = [[t.strip()] for t in args.text.split(',')] + [[' ']]
    # 设置输出目录
    output_dir = args.output_dir
    # 如果输出目录不存在,则创建
    if not osp.exists(output_dir):
        os.mkdir(output_dir)
    # 在运行之前调用钩子函数
    runner.call_hook('before_run')
    # 加载或恢复模型
    runner.load_or_resume()
    # 获取数据处理流程
    pipeline = cfg.test_dataloader.dataset.pipeline
    runner.pipeline = Compose(pipeline)
    # 设置模型为评估模式
    runner.model.eval()
    # 检查输入的图像路径是否为文件夹
    if not osp.isfile(args.image):
        # 获取文件夹中所有以 .png 或 .jpg 结尾的图像文件路径
        images = [
            osp.join(args.image, img) for img in os.listdir(args.image)
            if img.endswith('.png') or img.endswith('.jpg')
        ]
    else:
        # 将输入的图像路径转换为列表形式
        images = [args.image]
    # 创建进度条对象,用于显示处理进度
    progress_bar = ProgressBar(len(images))
    # 遍历每张图像进行目标检测
    for image_path in images:
        # 调用目标检测函数进行推理
        inference_detector(runner,
                           image_path,
                           texts,
                           args.topk,
                           args.threshold,
                           output_dir=output_dir,
                           use_amp=args.amp,
                           show=args.show,
                           annotation=args.annotation)
        # 更新进度条
        progress_bar.update()

.\YOLO-World\tools\test.py

# 版权声明
# 导入必要的库
import argparse
import os
import os.path as osp
# 导入自定义模块
from mmdet.engine.hooks.utils import trigger_visualization_hook
from mmengine.config import Config, ConfigDict, DictAction
from mmengine.evaluator import DumpResults
from mmengine.runner import Runner
# 导入自定义模块
from mmyolo.registry import RUNNERS
from mmyolo.utils import is_metainfo_lower
# 定义解析命令行参数的函数
def parse_args():
    # 创建 ArgumentParser 对象,设置描述信息
    parser = argparse.ArgumentParser(
        description='MMYOLO test (and eval) a model')
    # 添加命令行参数
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--work-dir',
        help='the directory to save the file containing evaluation metrics')
    parser.add_argument(
        '--out',
        type=str,
        help='output result file (must be a .pkl file) in pickle format')
    parser.add_argument(
        '--json-prefix',
        type=str,
        help='the prefix of the output json file without perform evaluation, '
        'which is useful when you want to format the result to a specific '
        'format and submit it to the test server')
    parser.add_argument(
        '--tta',
        action='store_true',
        help='Whether to use test time augmentation')
    parser.add_argument(
        '--show', action='store_true', help='show prediction results')
    parser.add_argument(
        '--deploy',
        action='store_true',
        help='Switch model to deployment mode')
    parser.add_argument(
        '--show-dir',
        help='directory where painted images will be saved. '
        'If specified, it will be automatically saved '
        'to the work_dir/timestamp/show_dir')
    parser.add_argument(
        '--wait-time', type=float, default=2, help='the interval of show (s)')
    # 添加一个命令行参数,用于覆盖配置文件中的一些设置,参数为字典类型,使用自定义的DictAction处理
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    
    # 添加一个命令行参数,用于指定作业启动器的类型,可选值为['none', 'pytorch', 'slurm', 'mpi'],默认为'none'
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    
    # 添加一个命令行参数,用于指定本地进程的排名,默认为0
    parser.add_argument('--local_rank', type=int, default=0)
    
    # 解析命令行参数并返回结果
    args = parser.parse_args()
    
    # 如果环境变量中没有'LOCAL_RANK',则将命令行参数中的local_rank值赋给'LOCAL_RANK'
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    
    # 返回解析后的命令行参数
    return args
def main():
    # 解析命令行参数
    args = parse_args()
    # 加载配置文件
    cfg = Config.fromfile(args.config)
    # 用 cfg.key 的值替换 ${key}
    # cfg = replace_cfg_vals(cfg)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        # 根据命令行参数更新配置
        cfg.merge_from_dict(args.cfg_options)
    # 确定工作目录的优先级:CLI > 配置文件中的段 > 文件名
    if args.work_dir is not None:
        # 如果 args.work_dir 不为 None,则根据 CLI 参数更新配置
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # 如果 cfg.work_dir 为 None,则使用配置文件名作为默认工作目录
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    # 加载模型参数
    cfg.load_from = args.checkpoint
    if args.show or args.show_dir:
        # 触发可视化钩子
        cfg = trigger_visualization_hook(cfg, args)
    if args.deploy:
        # 添加部署钩子
        cfg.custom_hooks.append(dict(type='SwitchToDeployHook'))
    # 将 `format_only` 和 `outfile_prefix` 添加到配置中
    if args.json_prefix is not None:
        cfg_json = {
            'test_evaluator.format_only': True,
            'test_evaluator.outfile_prefix': args.json_prefix
        }
        cfg.merge_from_dict(cfg_json)
    # 确定自定义元信息字段是否全部为小写
    is_metainfo_lower(cfg)
    # 如果启用了测试时间增强(TTA),则需要检查配置中是否包含必要的参数
    if args.tta:
        # 检查配置中是否包含 tta_model 和 tta_pipeline,否则无法使用 TTA
        assert 'tta_model' in cfg, 'Cannot find ``tta_model`` in config.' \
                                   " Can't use tta !"
        assert 'tta_pipeline' in cfg, 'Cannot find ``tta_pipeline`` ' \
                                      "in config. Can't use tta !"
        # 将 tta_model 合并到 model 配置中
        cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model)
        test_data_cfg = cfg.test_dataloader.dataset
        while 'dataset' in test_data_cfg:
            test_data_cfg = test_data_cfg['dataset']
        # batch_shapes_cfg 会强制控制输出图像的大小,与 TTA 不兼容
        if 'batch_shapes_cfg' in test_data_cfg:
            test_data_cfg.batch_shapes_cfg = None
        test_data_cfg.pipeline = cfg.tta_pipeline
    # 根据配置构建 Runner 对象
    if 'runner_type' not in cfg:
        # 构建默认的 Runner
        runner = Runner.from_cfg(cfg)
    else:
        # 从注册表中构建自定义的 Runner,如果配置中设置了 runner_type
        runner = RUNNERS.build(cfg)
    # 添加 `DumpResults` 虚拟指标
    if args.out is not None:
        # 确保输出文件是 pkl 或 pickle 格式
        assert args.out.endswith(('.pkl', '.pickle')), \
            'The dump file must be a pkl file.'
        runner.test_evaluator.metrics.append(
            DumpResults(out_file_path=args.out))
    # 开始测试
    runner.test()
# 如果当前脚本被直接执行,则调用主函数
if __name__ == '__main__':
    main()

yolo-world 源码解析(四)(3)https://developer.aliyun.com/article/1483877

相关文章
|
2月前
|
监控 Java 应用服务中间件
高级java面试---spring.factories文件的解析源码API机制
【11月更文挑战第20天】Spring Boot是一个用于快速构建基于Spring框架的应用程序的开源框架。它通过自动配置、起步依赖和内嵌服务器等特性,极大地简化了Spring应用的开发和部署过程。本文将深入探讨Spring Boot的背景历史、业务场景、功能点以及底层原理,并通过Java代码手写模拟Spring Boot的启动过程,特别是spring.factories文件的解析源码API机制。
92 2
|
14天前
|
存储 设计模式 算法
【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析
行为型模式用于描述程序在运行时复杂的流程控制,即描述多个类或对象之间怎样相互协作共同完成单个对象都无法单独完成的任务,它涉及算法与对象间职责的分配。行为型模式分为类行为模式和对象行为模式,前者采用继承机制来在类间分派行为,后者采用组合或聚合在对象间分配行为。由于组合关系或聚合关系比继承关系耦合度低,满足“合成复用原则”,所以对象行为模式比类行为模式具有更大的灵活性。 行为型模式分为: • 模板方法模式 • 策略模式 • 命令模式 • 职责链模式 • 状态模式 • 观察者模式 • 中介者模式 • 迭代器模式 • 访问者模式 • 备忘录模式 • 解释器模式
【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析
|
14天前
|
设计模式 存储 安全
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
结构型模式描述如何将类或对象按某种布局组成更大的结构。它分为类结构型模式和对象结构型模式,前者采用继承机制来组织接口和类,后者釆用组合或聚合来组合对象。由于组合关系或聚合关系比继承关系耦合度低,满足“合成复用原则”,所以对象结构型模式比类结构型模式具有更大的灵活性。 结构型模式分为以下 7 种: • 代理模式 • 适配器模式 • 装饰者模式 • 桥接模式 • 外观模式 • 组合模式 • 享元模式
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
|
14天前
|
设计模式 存储 安全
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
创建型模式的主要关注点是“怎样创建对象?”,它的主要特点是"将对象的创建与使用分离”。这样可以降低系统的耦合度,使用者不需要关注对象的创建细节。创建型模式分为5种:单例模式、工厂方法模式抽象工厂式、原型模式、建造者模式。
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
|
2月前
|
缓存 监控 Java
Java线程池提交任务流程底层源码与源码解析
【11月更文挑战第30天】嘿,各位技术爱好者们,今天咱们来聊聊Java线程池提交任务的底层源码与源码解析。作为一个资深的Java开发者,我相信你一定对线程池并不陌生。线程池作为并发编程中的一大利器,其重要性不言而喻。今天,我将以对话的方式,带你一步步深入线程池的奥秘,从概述到功能点,再到背景和业务点,最后到底层原理和示例,让你对线程池有一个全新的认识。
58 12
|
1月前
|
PyTorch Shell API
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。
|
15天前
|
安全 搜索推荐 数据挖掘
陪玩系统源码开发流程解析,成品陪玩系统源码的优点
我们自主开发的多客陪玩系统源码,整合了市面上主流陪玩APP功能,支持二次开发。该系统适用于线上游戏陪玩、语音视频聊天、心理咨询等场景,提供用户注册管理、陪玩者资料库、预约匹配、实时通讯、支付结算、安全隐私保护、客户服务及数据分析等功能,打造综合性社交平台。随着互联网技术发展,陪玩系统正成为游戏爱好者的新宠,改变游戏体验并带来新的商业模式。
|
2月前
|
存储 安全 Linux
Golang的GMP调度模型与源码解析
【11月更文挑战第11天】GMP 调度模型是 Go 语言运行时系统的核心部分,用于高效管理和调度大量协程(goroutine)。它通过少量的操作系统线程(M)和逻辑处理器(P)来调度大量的轻量级协程(G),从而实现高性能的并发处理。GMP 模型通过本地队列和全局队列来减少锁竞争,提高调度效率。在 Go 源码中,`runtime.h` 文件定义了关键数据结构,`schedule()` 和 `findrunnable()` 函数实现了核心调度逻辑。通过深入研究 GMP 模型,可以更好地理解 Go 语言的并发机制。
|
2月前
|
消息中间件 缓存 安全
Future与FutureTask源码解析,接口阻塞问题及解决方案
【11月更文挑战第5天】在Java开发中,多线程编程是提高系统并发性能和资源利用率的重要手段。然而,多线程编程也带来了诸如线程安全、死锁、接口阻塞等一系列复杂问题。本文将深度剖析多线程优化技巧、Future与FutureTask的源码、接口阻塞问题及解决方案,并通过具体业务场景和Java代码示例进行实战演示。
64 3
|
3月前
|
存储
让星星⭐月亮告诉你,HashMap的put方法源码解析及其中两种会触发扩容的场景(足够详尽,有问题欢迎指正~)
`HashMap`的`put`方法通过调用`putVal`实现,主要涉及两个场景下的扩容操作:1. 初始化时,链表数组的初始容量设为16,阈值设为12;2. 当存储的元素个数超过阈值时,链表数组的容量和阈值均翻倍。`putVal`方法处理键值对的插入,包括链表和红黑树的转换,确保高效的数据存取。
69 5

推荐镜像

更多