yolo-world 源码解析(四)(2)https://developer.aliyun.com/article/1483876
.\YOLO-World\tools\train.py
# Standard-library and framework imports used by the training entry point.
import argparse  # command-line argument parsing
import logging  # log-level constants
import os  # environment variables
import os.path as osp  # path manipulation

from mmengine.config import Config, DictAction
from mmengine.logging import print_log
from mmengine.runner import Runner

from mmyolo.registry import RUNNERS
from mmyolo.utils import is_metainfo_lower


def parse_args():
    """Parse the command-line arguments of the training script.

    As a side effect, the ``LOCAL_RANK`` environment variable is set to the
    value of ``--local_rank`` when it is not already present.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    resume_help = ('If specify checkpoint path, resume from it, while if not '
                   'specify, try to auto resume from the latest checkpoint '
                   'in the work directory.')
    cfg_options_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')

    parser = argparse.ArgumentParser(description='Train a detector')
    # Positional: path of the training config file.
    parser.add_argument('config', help='train config file path')
    # Optional: output directory for logs and checkpoints.
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    # Optional flag: switch on automatic mixed precision.
    parser.add_argument(
        '--amp',
        action='store_true',
        default=False,
        help='enable automatic-mixed-precision training')
    # Optional: bare ``--resume`` yields 'auto', otherwise a checkpoint path.
    parser.add_argument(
        '--resume', nargs='?', type=str, const='auto', help=resume_help)
    # Optional: key=value overrides merged into the loaded config.
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help=cfg_options_help)
    # Optional: which job launcher started this process.
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)

    args = parser.parse_args()
    # Only fill in LOCAL_RANK when the environment does not define it yet.
    os.environ.setdefault('LOCAL_RANK', str(args.local_rank))
    return args


def _default_work_dir(config_path):
    """Derive the fallback work directory from the config file path."""
    prefix = 'projects/'
    if config_path.startswith(prefix):
        # Keep the project-relative layout, minus the '/configs/' segment.
        relative = config_path[len(prefix):].replace('/configs/', '/')
        stem = osp.splitext(relative)[0]
    else:
        stem = osp.splitext(osp.basename(config_path))[0]
    return osp.join('./work_dirs', stem)


def main():
    """Load the config, apply the CLI overrides, build a runner and train."""
    args = parse_args()

    # Load the config file and record the launcher chosen on the CLI.
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher

    # Merge ``--cfg-options`` overrides, if any were given.
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # work_dir priority: CLI > value in the config file > config file name.
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        cfg.work_dir = _default_work_dir(args.config)

    # Enable automatic mixed precision training when requested.
    if args.amp:
        wrapper_type = cfg.optim_wrapper.type
        if wrapper_type != 'AmpOptimWrapper':
            assert wrapper_type == 'OptimWrapper', (
                '`--amp` is only supported when the optimizer wrapper type is '
                f'`OptimWrapper` but got {wrapper_type}.')
            cfg.optim_wrapper.type = 'AmpOptimWrapper'
            cfg.optim_wrapper.loss_scale = 'dynamic'
        else:
            print_log(
                'AMP training is already enabled in your config.',
                logger='current',
                level=logging.WARNING)

    # Resume priority: explicit checkpoint path > auto resume.
    if args.resume is not None:
        cfg.resume = True
        cfg.load_from = None if args.resume == 'auto' else args.resume

    # Check the letter case of the custom metainfo fields in the config.
    is_metainfo_lower(cfg)

    # A config with 'runner_type' is built through the RUNNERS registry;
    # otherwise the default Runner is used.
    if 'runner_type' in cfg:
        runner = RUNNERS.build(cfg)
    else:
        runner = Runner.from_cfg(cfg)

    # Start training.
    runner.train()


# Run ``main`` only when this file is executed as a script.
if __name__ == '__main__':
    main()
.\YOLO-World\yolo_world\datasets\mm_dataset.py
# Imports required by the multi-modal dataset wrappers.
import copy
import json
import logging
from typing import Callable, List, Union

from mmengine.logging import print_log
from mmengine.dataset.base_dataset import (
    BaseDataset, Compose, force_full_init)
from mmyolo.registry import DATASETS


# Register MultiModalDataset in the DATASETS registry.
@DATASETS.register_module()
class MultiModalDataset:
    """Multi-modal dataset.

    Wraps another dataset and, when a class-text file is given, attaches the
    loaded texts to every sample under the ``'texts'`` key.

    Args:
        dataset (BaseDataset or dict): The wrapped dataset, or a config dict
            that is built through the ``DATASETS`` registry.
        class_text_path (str, optional): Path of a JSON file holding the
            class texts. ``None`` disables the ``'texts'`` field.
        test_mode (bool): Whether this wrapper is in test mode.
        pipeline (list): Transforms (config dicts or callables) handed to
            ``Compose``.
        lazy_init (bool): When ``True``, ``full_init`` is not called from
            the constructor.

    Raises:
        TypeError: If ``dataset`` is neither a dict nor a ``BaseDataset``.
    """

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 class_text_path: str = None,
                 test_mode: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 lazy_init: bool = False) -> None:
        # Resolve the wrapped dataset.
        self.dataset: BaseDataset
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'dataset must be a dict or a BaseDataset, '
                f'but got {dataset}')

        # Load the class texts. The file is opened in a ``with`` block so
        # the handle is closed even if JSON parsing fails, and read as
        # UTF-8 so the result does not depend on the platform default.
        # NOTE: the number of texts is not checked against the dataset's
        # classes.
        if class_text_path is not None:
            with open(class_text_path, 'r', encoding='utf-8') as text_file:
                self.class_texts = json.load(text_file)
        else:
            self.class_texts = None

        # Test-mode flag of this wrapper.
        self.test_mode = test_mode
        # Meta information taken from the wrapped dataset.
        self._metainfo = self.dataset.metainfo
        # Data processing pipeline.
        self.pipeline = Compose(pipeline)

        # Tracks whether ``full_init`` has already run.
        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        """dict: A deep copy of the wrapped dataset's meta information."""
        return copy.deepcopy(self._metainfo)

    def full_init(self) -> None:
        """``full_init`` dataset.

        Fully initializes the wrapped dataset and caches its length. Does
        nothing when it has already run.
        """
        if self._fully_initialized:
            return

        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True

    # NOTE(review): ``force_full_init`` comes from mmengine; presumably it
    # runs ``full_init`` before the call when needed -- confirm upstream.
    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Index of the sample.

        Returns:
            dict: Data info from the wrapped dataset, plus a ``'texts'``
            entry when class texts were loaded.
        """
        data_info = self.dataset.get_data_info(idx)
        if self.class_texts is not None:
            data_info.update({'texts': self.class_texts})
        return data_info

    def __getitem__(self, idx):
        """Return the pipeline output for the sample at ``idx``."""
        # Initialize on demand, but warn because doing it here is slower
        # than calling ``full_init`` up front.
        if not self._fully_initialized:
            print_log(
                'Please call `full_init` method manually to '
                'accelerate the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()

        data_info = self.get_data_info(idx)

        # Attach this wrapper to the sample when either the wrapped dataset
        # or the wrapper itself is not in test mode.
        inner_is_training = (hasattr(self.dataset, 'test_mode')
                             and not self.dataset.test_mode)
        if inner_is_training or not self.test_mode:
            data_info['dataset'] = self

        return self.pipeline(data_info)

    @force_full_init
    def __len__(self) -> int:
        """Return the length cached by ``full_init``."""
        return self._ori_len


# Register MultiModalMixedDataset in the DATASETS registry.
@DATASETS.register_module()
class MultiModalMixedDataset(MultiModalDataset):
    """Multi-modal Mixed dataset.

    mix "detection dataset" and "caption dataset"

    Args:
        dataset_type (str): dataset type, 'detection' or 'caption'
    """

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 class_text_path: str = None,
                 dataset_type: str = 'detection',
                 test_mode: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 lazy_init: bool = False) -> None:
        # Record the dataset type, then delegate to the parent constructor.
        self.dataset_type = dataset_type
        super().__init__(dataset, class_text_path, test_mode, pipeline,
                         lazy_init)

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Index of the sample.

        Returns:
            dict: Data info from the wrapped dataset, plus ``'texts'`` when
            class texts were loaded and ``'is_detection'`` set to 1 for a
            'detection' dataset type, 0 otherwise.
        """
        data_info = self.dataset.get_data_info(idx)
        if self.class_texts is not None:
            data_info.update({'texts': self.class_texts})
        data_info['is_detection'] = 1 \
            if self.dataset_type == 'detection' else 0
        return data_info
yolo-world 源码解析(四)(4)https://developer.aliyun.com/article/1483878