yolo-world 源码解析(四)(3)

本文涉及的产品
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
云解析 DNS,旗舰版 1个月
简介: yolo-world 源码解析(四)

yolo-world 源码解析(四)(2)https://developer.aliyun.com/article/1483876

.\YOLO-World\tools\train.py

# 导入必要的库和模块
import argparse  # 用于解析命令行参数
import logging  # 用于记录日志
import os  # 用于操作系统相关功能
import os.path as osp  # 用于操作文件路径
from mmengine.config import Config, DictAction  # 导入Config和DictAction类
from mmengine.logging import print_log  # 导入print_log函数
from mmengine.runner import Runner  # 导入Runner类
from mmyolo.registry import RUNNERS  # 导入RUNNERS变量
from mmyolo.utils import is_metainfo_lower  # 导入is_metainfo_lower函数
# 解析命令行参数
def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')  # 创建参数解析器
    parser.add_argument('config', help='train config file path')  # 添加必需的参数
    parser.add_argument('--work-dir', help='the dir to save logs and models')  # 添加可选参数
    parser.add_argument(
        '--amp',
        action='store_true',
        default=False,
        help='enable automatic-mixed-precision training')  # 添加可选参数
    parser.add_argument(
        '--resume',
        nargs='?',
        type=str,
        const='auto',
        help='If specify checkpoint path, resume from it, while if not '
        'specify, try to auto resume from the latest checkpoint '
        'in the work directory.')  # 添加可选参数
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')  # 添加可选参数
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')  # 添加可选参数
    parser.add_argument('--local_rank', type=int, default=0)  # 添加可选参数
    args = parser.parse_args()  # 解析命令行参数
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)  # 设置环境变量LOCAL_RANK为args.local_rank的值
    return args  # 返回解析后的参数
def main():
    args = parse_args()  # 解析命令行参数并保存到args变量中
    # 加载配置文件
    cfg = Config.fromfile(args.config)  # 从配置文件路径args.config中加载配置信息
    # 用cfg.key的值替换${key}的占位符
    # 设置配置文件中的 launcher 为命令行参数中指定的 launcher
    cfg.launcher = args.launcher
    # 如果命令行参数中指定了 cfg_options,则将其合并到配置文件中
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # 确定工作目录的优先级:CLI > 文件中的段 > 文件名
    if args.work_dir is not None:
        # 如果命令行参数中指定了 work_dir,则更新配置文件中的 work_dir
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # 如果配置文件中的 work_dir 为 None,则根据配置文件名设置默认的 work_dir
        if args.config.startswith('projects/'):
            config = args.config[len('projects/'):]
            config = config.replace('/configs/', '/')
            cfg.work_dir = osp.join('./work_dirs', osp.splitext(config)[0])
        else:
            cfg.work_dir = osp.join('./work_dirs',
                                    osp.splitext(osp.basename(args.config))[0])
    # 启用自动混合精度训练
    if args.amp is True:
        optim_wrapper = cfg.optim_wrapper.type
        if optim_wrapper == 'AmpOptimWrapper':
            print_log(
                'AMP training is already enabled in your config.',
                logger='current',
                level=logging.WARNING)
        else:
            assert optim_wrapper == 'OptimWrapper', (
                '`--amp` is only supported when the optimizer wrapper type is '
                f'`OptimWrapper` but got {optim_wrapper}.')
            cfg.optim_wrapper.type = 'AmpOptimWrapper'
            cfg.optim_wrapper.loss_scale = 'dynamic'
    # 确定恢复训练的优先级:resume from > auto_resume
    if args.resume == 'auto':
        cfg.resume = True
        cfg.load_from = None
    elif args.resume is not None:
        cfg.resume = True
        cfg.load_from = args.resume
    # 确定自定义元信息字段是否全部为小写
    is_metainfo_lower(cfg)
    # 从配置文件构建 runner
    # 如果配置中没有指定 'runner_type'
    if 'runner_type' not in cfg:
        # 构建默认的运行器
        runner = Runner.from_cfg(cfg)
    else:
        # 从注册表中构建定制的运行器
        # 如果配置中设置了 'runner_type'
        runner = RUNNERS.build(cfg)
    # 开始训练
    runner.train()
# 如果当前脚本被直接执行,则调用主函数
if __name__ == '__main__':
    main()

.\YOLO-World\yolo_world\datasets\mm_dataset.py

# 导入所需的模块和类
import copy
import json
import logging
from typing import Callable, List, Union
from mmengine.logging import print_log
from mmengine.dataset.base_dataset import (
        BaseDataset, Compose, force_full_init)
from mmyolo.registry import DATASETS
# 注册MultiModalDataset类到DATASETS
@DATASETS.register_module()
class MultiModalDataset:
    """Multi-modal dataset."""
    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 class_text_path: str = None,
                 test_mode: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 lazy_init: bool = False) -> None:
        # 初始化dataset属性
        self.dataset: BaseDataset
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'dataset must be a dict or a BaseDataset, '
                f'but got {dataset}')
        # 加载类别文本文件
        if class_text_path is not None:
            self.class_texts = json.load(open(class_text_path, 'r'))
            # ori_classes = self.dataset.metainfo['classes']
            # assert len(ori_classes) == len(self.class_texts), \
            #     ('The number of classes in the dataset and the class text'
            #      'file must be the same.')
        else:
            self.class_texts = None
        # 设置测试模式
        self.test_mode = test_mode
        # 获取数据集的元信息
        self._metainfo = self.dataset.metainfo
        # 初始化数据处理pipeline
        self.pipeline = Compose(pipeline)
        # 标记是否已完全初始化
        self._fully_initialized = False
        # 如果不是延迟初始化,则进行完全初始化
        if not lazy_init:
            self.full_init()
    @property
    def metainfo(self) -> dict:
        # 返回元信息的深拷贝
        return copy.deepcopy(self._metainfo)
    def full_init(self) -> None:
        """``full_init`` dataset."""
        # 如果已经完全初始化,则直接返回
        if self._fully_initialized:
            return
        # 对数据集进行完全初始化
        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True
    @force_full_init
    # 根据索引获取数据信息,返回一个字典
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index."""
        # 通过数据集对象获取指定索引的数据信息
        data_info = self.dataset.get_data_info(idx)
        # 如果类别文本不为空,则将其添加到数据信息字典中
        if self.class_texts is not None:
            data_info.update({'texts': self.class_texts})
        return data_info
    # 根据索引获取数据
    def __getitem__(self, idx):
        # 如果数据集未完全初始化,则打印警告信息并手动调用`full_init`方法以加快速度
        if not self._fully_initialized:
            print_log(
                'Please call `full_init` method manually to '
                'accelerate the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()
        # 获取数据信息
        data_info = self.get_data_info(idx)
        # 如果数据集具有'test_mode'属性且不为测试模式,则将数据集信息添加到数据信息字典中
        if hasattr(self.dataset, 'test_mode') and not self.dataset.test_mode:
            data_info['dataset'] = self
        # 如果不是测试模式,则将数据集信息添加到数据信息字典中
        elif not self.test_mode:
            data_info['dataset'] = self
        # 返回经过管道处理后的数据信息
        return self.pipeline(data_info)
    # 返回数据集的长度
    @force_full_init
    def __len__(self) -> int:
        return self._ori_len
# 注册 MultiModalMixedDataset 类到 DATASETS 模块
@DATASETS.register_module()
class MultiModalMixedDataset(MultiModalDataset):
    """Multi-modal Mixed dataset.
    mix "detection dataset" and "caption dataset"
    Args:
        dataset_type (str): dataset type, 'detection' or 'caption'
    """
    # 初始化方法,接受多种参数,包括 dataset、class_text_path、dataset_type、test_mode、pipeline 和 lazy_init
    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 class_text_path: str = None,
                 dataset_type: str = 'detection',
                 test_mode: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 lazy_init: bool = False) -> None:
        # 设置 dataset_type 属性
        self.dataset_type = dataset_type
        # 调用父类的初始化方法
        super().__init__(dataset,
                         class_text_path,
                         test_mode,
                         pipeline,
                         lazy_init)
    # 强制完全初始化装饰器,用于 get_data_info 方法
    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index."""
        # 调用 dataset 的 get_data_info 方法获取数据信息
        data_info = self.dataset.get_data_info(idx)
        # 如果 class_texts 不为空,则更新 data_info 中的 'texts' 字段
        if self.class_texts is not None:
            data_info.update({'texts': self.class_texts})
        # 根据 dataset_type 设置 data_info 中的 'is_detection' 字段
        data_info['is_detection'] = 1 \
            if self.dataset_type == 'detection' else 0
        return data_info

yolo-world 源码解析(四)(4)https://developer.aliyun.com/article/1483878

相关文章
|
2月前
|
监控 Java 应用服务中间件
高级java面试---spring.factories文件的解析源码API机制
【11月更文挑战第20天】Spring Boot是一个用于快速构建基于Spring框架的应用程序的开源框架。它通过自动配置、起步依赖和内嵌服务器等特性,极大地简化了Spring应用的开发和部署过程。本文将深入探讨Spring Boot的背景历史、业务场景、功能点以及底层原理,并通过Java代码手写模拟Spring Boot的启动过程,特别是spring.factories文件的解析源码API机制。
87 2
|
13天前
|
存储 设计模式 算法
【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析
行为型模式用于描述程序在运行时复杂的流程控制,即描述多个类或对象之间怎样相互协作共同完成单个对象都无法单独完成的任务,它涉及算法与对象间职责的分配。行为型模式分为类行为模式和对象行为模式,前者采用继承机制来在类间分派行为,后者采用组合或聚合在对象间分配行为。由于组合关系或聚合关系比继承关系耦合度低,满足“合成复用原则”,所以对象行为模式比类行为模式具有更大的灵活性。 行为型模式分为: • 模板方法模式 • 策略模式 • 命令模式 • 职责链模式 • 状态模式 • 观察者模式 • 中介者模式 • 迭代器模式 • 访问者模式 • 备忘录模式 • 解释器模式
【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析
|
13天前
|
设计模式 存储 安全
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
结构型模式描述如何将类或对象按某种布局组成更大的结构。它分为类结构型模式和对象结构型模式,前者采用继承机制来组织接口和类,后者釆用组合或聚合来组合对象。由于组合关系或聚合关系比继承关系耦合度低,满足“合成复用原则”,所以对象结构型模式比类结构型模式具有更大的灵活性。 结构型模式分为以下 7 种: • 代理模式 • 适配器模式 • 装饰者模式 • 桥接模式 • 外观模式 • 组合模式 • 享元模式
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
|
13天前
|
设计模式 存储 安全
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
创建型模式的主要关注点是“怎样创建对象?”,它的主要特点是"将对象的创建与使用分离”。这样可以降低系统的耦合度,使用者不需要关注对象的创建细节。创建型模式分为5种:单例模式、工厂方法模式抽象工厂式、原型模式、建造者模式。
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
|
2月前
|
缓存 监控 Java
Java线程池提交任务流程底层源码与源码解析
【11月更文挑战第30天】嘿,各位技术爱好者们,今天咱们来聊聊Java线程池提交任务的底层源码与源码解析。作为一个资深的Java开发者,我相信你一定对线程池并不陌生。线程池作为并发编程中的一大利器,其重要性不言而喻。今天,我将以对话的方式,带你一步步深入线程池的奥秘,从概述到功能点,再到背景和业务点,最后到底层原理和示例,让你对线程池有一个全新的认识。
57 12
|
1月前
|
PyTorch Shell API
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。
|
14天前
|
安全 搜索推荐 数据挖掘
陪玩系统源码开发流程解析,成品陪玩系统源码的优点
我们自主开发的多客陪玩系统源码,整合了市面上主流陪玩APP功能,支持二次开发。该系统适用于线上游戏陪玩、语音视频聊天、心理咨询等场景,提供用户注册管理、陪玩者资料库、预约匹配、实时通讯、支付结算、安全隐私保护、客户服务及数据分析等功能,打造综合性社交平台。随着互联网技术发展,陪玩系统正成为游戏爱好者的新宠,改变游戏体验并带来新的商业模式。
|
2月前
|
存储 安全 Linux
Golang的GMP调度模型与源码解析
【11月更文挑战第11天】GMP 调度模型是 Go 语言运行时系统的核心部分,用于高效管理和调度大量协程(goroutine)。它通过少量的操作系统线程(M)和逻辑处理器(P)来调度大量的轻量级协程(G),从而实现高性能的并发处理。GMP 模型通过本地队列和全局队列来减少锁竞争,提高调度效率。在 Go 源码中,`runtime.h` 文件定义了关键数据结构,`schedule()` 和 `findrunnable()` 函数实现了核心调度逻辑。通过深入研究 GMP 模型,可以更好地理解 Go 语言的并发机制。
|
2月前
|
消息中间件 缓存 安全
Future与FutureTask源码解析,接口阻塞问题及解决方案
【11月更文挑战第5天】在Java开发中,多线程编程是提高系统并发性能和资源利用率的重要手段。然而,多线程编程也带来了诸如线程安全、死锁、接口阻塞等一系列复杂问题。本文将深度剖析多线程优化技巧、Future与FutureTask的源码、接口阻塞问题及解决方案,并通过具体业务场景和Java代码示例进行实战演示。
63 3
|
3月前
|
存储
让星星⭐月亮告诉你,HashMap的put方法源码解析及其中两种会触发扩容的场景(足够详尽,有问题欢迎指正~)
`HashMap`的`put`方法通过调用`putVal`实现,主要涉及两个场景下的扩容操作:1. 初始化时,链表数组的初始容量设为16,阈值设为12;2. 当存储的元素个数超过阈值时,链表数组的容量和阈值均翻倍。`putVal`方法处理键值对的插入,包括链表和红黑树的转换,确保高效的数据存取。
68 5