yolo-world 源码解析(四)(3)

简介: yolo-world 源码解析(四)

yolo-world 源码解析(四)(2)https://developer.aliyun.com/article/1483876

.\YOLO-World\tools\train.py

# 导入必要的库和模块
import argparse  # 用于解析命令行参数
import logging  # 用于记录日志
import os  # 用于操作系统相关功能
import os.path as osp  # 用于操作文件路径
from mmengine.config import Config, DictAction  # 导入Config和DictAction类
from mmengine.logging import print_log  # 导入print_log函数
from mmengine.runner import Runner  # 导入Runner类
from mmyolo.registry import RUNNERS  # 导入RUNNERS变量
from mmyolo.utils import is_metainfo_lower  # 导入is_metainfo_lower函数
# 解析命令行参数
def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')  # 创建参数解析器
    parser.add_argument('config', help='train config file path')  # 添加必需的参数
    parser.add_argument('--work-dir', help='the dir to save logs and models')  # 添加可选参数
    parser.add_argument(
        '--amp',
        action='store_true',
        default=False,
        help='enable automatic-mixed-precision training')  # 添加可选参数
    parser.add_argument(
        '--resume',
        nargs='?',
        type=str,
        const='auto',
        help='If specify checkpoint path, resume from it, while if not '
        'specify, try to auto resume from the latest checkpoint '
        'in the work directory.')  # 添加可选参数
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')  # 添加可选参数
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')  # 添加可选参数
    parser.add_argument('--local_rank', type=int, default=0)  # 添加可选参数
    args = parser.parse_args()  # 解析命令行参数
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)  # 设置环境变量LOCAL_RANK为args.local_rank的值
    return args  # 返回解析后的参数
def main():
    args = parse_args()  # 解析命令行参数并保存到args变量中
    # 加载配置文件
    cfg = Config.fromfile(args.config)  # 从配置文件路径args.config中加载配置信息
    # 用cfg.key的值替换${key}的占位符
    # 设置配置文件中的 launcher 为命令行参数中指定的 launcher
    cfg.launcher = args.launcher
    # 如果命令行参数中指定了 cfg_options,则将其合并到配置文件中
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # 确定工作目录的优先级:CLI > 文件中的段 > 文件名
    if args.work_dir is not None:
        # 如果命令行参数中指定了 work_dir,则更新配置文件中的 work_dir
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # 如果配置文件中的 work_dir 为 None,则根据配置文件名设置默认的 work_dir
        if args.config.startswith('projects/'):
            config = args.config[len('projects/'):]
            config = config.replace('/configs/', '/')
            cfg.work_dir = osp.join('./work_dirs', osp.splitext(config)[0])
        else:
            cfg.work_dir = osp.join('./work_dirs',
                                    osp.splitext(osp.basename(args.config))[0])
    # 启用自动混合精度训练
    if args.amp is True:
        optim_wrapper = cfg.optim_wrapper.type
        if optim_wrapper == 'AmpOptimWrapper':
            print_log(
                'AMP training is already enabled in your config.',
                logger='current',
                level=logging.WARNING)
        else:
            assert optim_wrapper == 'OptimWrapper', (
                '`--amp` is only supported when the optimizer wrapper type is '
                f'`OptimWrapper` but got {optim_wrapper}.')
            cfg.optim_wrapper.type = 'AmpOptimWrapper'
            cfg.optim_wrapper.loss_scale = 'dynamic'
    # 确定恢复训练的优先级:resume from > auto_resume
    if args.resume == 'auto':
        cfg.resume = True
        cfg.load_from = None
    elif args.resume is not None:
        cfg.resume = True
        cfg.load_from = args.resume
    # 确定自定义元信息字段是否全部为小写
    is_metainfo_lower(cfg)
    # 从配置文件构建 runner
    # 如果配置中没有指定 'runner_type'
    if 'runner_type' not in cfg:
        # 构建默认的运行器
        runner = Runner.from_cfg(cfg)
    else:
        # 从注册表中构建定制的运行器
        # 如果配置中设置了 'runner_type'
        runner = RUNNERS.build(cfg)
    # 开始训练
    runner.train()
# 如果当前脚本被直接执行,则调用主函数
if __name__ == '__main__':
    main()

.\YOLO-World\yolo_world\datasets\mm_dataset.py

# 导入所需的模块和类
import copy
import json
import logging
from typing import Callable, List, Union
from mmengine.logging import print_log
from mmengine.dataset.base_dataset import (
        BaseDataset, Compose, force_full_init)
from mmyolo.registry import DATASETS
# 注册MultiModalDataset类到DATASETS
@DATASETS.register_module()
class MultiModalDataset:
    """Multi-modal dataset."""
    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 class_text_path: str = None,
                 test_mode: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 lazy_init: bool = False) -> None:
        # 初始化dataset属性
        self.dataset: BaseDataset
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'dataset must be a dict or a BaseDataset, '
                f'but got {dataset}')
        # 加载类别文本文件
        if class_text_path is not None:
            self.class_texts = json.load(open(class_text_path, 'r'))
            # ori_classes = self.dataset.metainfo['classes']
            # assert len(ori_classes) == len(self.class_texts), \
            #     ('The number of classes in the dataset and the class text'
            #      'file must be the same.')
        else:
            self.class_texts = None
        # 设置测试模式
        self.test_mode = test_mode
        # 获取数据集的元信息
        self._metainfo = self.dataset.metainfo
        # 初始化数据处理pipeline
        self.pipeline = Compose(pipeline)
        # 标记是否已完全初始化
        self._fully_initialized = False
        # 如果不是延迟初始化,则进行完全初始化
        if not lazy_init:
            self.full_init()
    @property
    def metainfo(self) -> dict:
        # 返回元信息的深拷贝
        return copy.deepcopy(self._metainfo)
    def full_init(self) -> None:
        """``full_init`` dataset."""
        # 如果已经完全初始化,则直接返回
        if self._fully_initialized:
            return
        # 对数据集进行完全初始化
        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True
    @force_full_init
    # 根据索引获取数据信息,返回一个字典
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index."""
        # 通过数据集对象获取指定索引的数据信息
        data_info = self.dataset.get_data_info(idx)
        # 如果类别文本不为空,则将其添加到数据信息字典中
        if self.class_texts is not None:
            data_info.update({'texts': self.class_texts})
        return data_info
    # 根据索引获取数据
    def __getitem__(self, idx):
        # 如果数据集未完全初始化,则打印警告信息并手动调用`full_init`方法以加快速度
        if not self._fully_initialized:
            print_log(
                'Please call `full_init` method manually to '
                'accelerate the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()
        # 获取数据信息
        data_info = self.get_data_info(idx)
        # 如果数据集具有'test_mode'属性且不为测试模式,则将数据集信息添加到数据信息字典中
        if hasattr(self.dataset, 'test_mode') and not self.dataset.test_mode:
            data_info['dataset'] = self
        # 如果不是测试模式,则将数据集信息添加到数据信息字典中
        elif not self.test_mode:
            data_info['dataset'] = self
        # 返回经过管道处理后的数据信息
        return self.pipeline(data_info)
    # 返回数据集的长度
    @force_full_init
    def __len__(self) -> int:
        return self._ori_len
# 注册 MultiModalMixedDataset 类到 DATASETS 模块
@DATASETS.register_module()
class MultiModalMixedDataset(MultiModalDataset):
    """Multi-modal Mixed dataset.
    mix "detection dataset" and "caption dataset"
    Args:
        dataset_type (str): dataset type, 'detection' or 'caption'
    """
    # 初始化方法,接受多种参数,包括 dataset、class_text_path、dataset_type、test_mode、pipeline 和 lazy_init
    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 class_text_path: str = None,
                 dataset_type: str = 'detection',
                 test_mode: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 lazy_init: bool = False) -> None:
        # 设置 dataset_type 属性
        self.dataset_type = dataset_type
        # 调用父类的初始化方法
        super().__init__(dataset,
                         class_text_path,
                         test_mode,
                         pipeline,
                         lazy_init)
    # 强制完全初始化装饰器,用于 get_data_info 方法
    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index."""
        # 调用 dataset 的 get_data_info 方法获取数据信息
        data_info = self.dataset.get_data_info(idx)
        # 如果 class_texts 不为空,则更新 data_info 中的 'texts' 字段
        if self.class_texts is not None:
            data_info.update({'texts': self.class_texts})
        # 根据 dataset_type 设置 data_info 中的 'is_detection' 字段
        data_info['is_detection'] = 1 \
            if self.dataset_type == 'detection' else 0
        return data_info

yolo-world 源码解析(四)(4)https://developer.aliyun.com/article/1483878

相关文章
|
3天前
|
Linux 网络安全 Windows
网络安全笔记-day8,DHCP部署_dhcp搭建部署,源码解析
网络安全笔记-day8,DHCP部署_dhcp搭建部署,源码解析
|
4天前
HuggingFace Tranformers 源码解析(4)
HuggingFace Tranformers 源码解析
6 0
|
4天前
HuggingFace Tranformers 源码解析(3)
HuggingFace Tranformers 源码解析
7 0
|
4天前
|
开发工具 git
HuggingFace Tranformers 源码解析(2)
HuggingFace Tranformers 源码解析
7 0
|
4天前
|
并行计算
HuggingFace Tranformers 源码解析(1)
HuggingFace Tranformers 源码解析
9 0
|
5天前
PandasTA 源码解析(二十三)
PandasTA 源码解析(二十三)
43 0
|
5天前
PandasTA 源码解析(二十二)(3)
PandasTA 源码解析(二十二)
35 0
|
5天前
PandasTA 源码解析(二十二)(2)
PandasTA 源码解析(二十二)
42 2
|
5天前
PandasTA 源码解析(二十二)(1)
PandasTA 源码解析(二十二)
33 0
|
5天前
PandasTA 源码解析(二十一)(4)
PandasTA 源码解析(二十一)
24 1

推荐镜像

更多