.\YOLO-World\yolo_world\datasets\transformers\mm_transforms.py
"""Text-loading transforms for multi-modal (image + text) detection data."""
import json
import random
from typing import Optional, Tuple

import numpy as np
from mmyolo.registry import TRANSFORMS


def _load_class_texts(text_path: str):
    """Load the class-text JSON file shared by the text transforms.

    The file is decoded explicitly as UTF-8 (the encoding required for JSON
    interchange by RFC 8259). This makes class names containing non-ASCII
    characters load identically regardless of the platform's default locale
    encoding (e.g. cp936/cp1252 on Windows).

    Args:
        text_path (str): Path to the JSON file.

    Returns:
        The decoded JSON content.
    """
    with open(text_path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Register RandomLoadText with the mmyolo TRANSFORMS registry.
@TRANSFORMS.register_module()
class RandomLoadText:
    """Holds the configuration for randomly sampled text prompts.

    NOTE(review): only ``__init__`` is present in this view. No ``__call__``
    is defined here, so instances as shown cannot be used as a pipeline step.
    The sampling logic implied by ``num_neg_samples`` / ``max_num_samples``
    presumably lives in a ``__call__`` that was omitted from this copy;
    confirm against the upstream source.

    Args:
        text_path (str, optional): JSON file with class texts. When given,
            its decoded content is stored in ``self.class_texts``.
        prompt_format (str): Stored as ``self.prompt_format``.
            Defaults to ``'{}'``.
        num_neg_samples (Tuple[int, int]): Stored as
            ``self.num_neg_samples``. Defaults to ``(80, 80)``.
        max_num_samples (int): Stored as ``self.max_num_samples``.
            Defaults to 80.
        padding_to_max (bool): Stored as ``self.padding_to_max``.
            Defaults to False.
        padding_value (str): Stored as ``self.padding_value``.
            Defaults to ``''``.
    """

    def __init__(self,
                 text_path: Optional[str] = None,
                 prompt_format: str = '{}',
                 num_neg_samples: Tuple[int, int] = (80, 80),
                 max_num_samples: int = 80,
                 padding_to_max: bool = False,
                 padding_value: str = '') -> None:
        self.prompt_format = prompt_format
        self.num_neg_samples = num_neg_samples
        self.max_num_samples = max_num_samples
        self.padding_to_max = padding_to_max
        self.padding_value = padding_value
        # ``class_texts`` is only set when a file is supplied.
        if text_path is not None:
            self.class_texts = _load_class_texts(text_path)


# Register LoadText with the mmyolo TRANSFORMS registry.
@TRANSFORMS.register_module()
class LoadText:
    """Put one formatted prompt per class into ``results['texts']``.

    Args:
        text_path (str, optional): JSON file holding a sequence of caption
            lists (one per class, judging by how ``__call__`` uses it). Only
            used when ``results`` carries no ``'texts'`` entry.
        prompt_format (str): ``str.format`` template applied to the selected
            caption. Defaults to ``'{}'`` (caption unchanged).
        multi_prompt_flag (str): Stored as ``self.multi_prompt_flag``. It is
            not read by ``__call__`` in this class. Defaults to ``'/'``.
    """

    def __init__(self,
                 text_path: Optional[str] = None,
                 prompt_format: str = '{}',
                 multi_prompt_flag: str = '/') -> None:
        self.prompt_format = prompt_format
        self.multi_prompt_flag = multi_prompt_flag
        # ``class_texts`` is only set when a file is supplied.
        if text_path is not None:
            self.class_texts = _load_class_texts(text_path)

    def __call__(self, results: dict) -> dict:
        """Replace ``results['texts']`` with one formatted caption per class.

        Texts already present in ``results['texts']`` take precedence over
        the texts loaded from ``text_path``. For every class, the first
        caption of its list is passed through ``prompt_format``.

        Args:
            results (dict): Pipeline results dict; modified in place.

        Returns:
            dict: The same ``results`` with ``'texts'`` set to a list of str.

        Raises:
            AssertionError: If neither ``results['texts']`` nor loaded class
                texts are available, or if a class has an empty caption list.
        """
        assert 'texts' in results or hasattr(self, 'class_texts'), (
            'No texts found in results.')
        class_texts = results.get(
            'texts', getattr(self, 'class_texts', None))

        texts = []
        for cls_caps in class_texts:
            assert len(cls_caps) > 0
            # Deterministic choice: always the first caption of each class.
            texts.append(self.prompt_format.format(cls_caps[0]))

        results['texts'] = texts
        return results
.\YOLO-World\yolo_world\datasets\transformers\__init__.py
# Copyright (c) Tencent Inc. All rights reserved.
"""Data transforms for multi-modal (image + text) pipelines."""
# The names below are imported so that every entry of ``__all__`` actually
# exists in this package. Without these imports,
# ``from .transformers import *`` fails on the undefined names.
from .mm_transforms import RandomLoadText, LoadText
from .mm_mix_img_transforms import (MultiModalMosaic, MultiModalMosaic9,
                                    YOLOv5MultiModalMixUp,
                                    YOLOXMultiModalMixUp)

__all__ = ['RandomLoadText', 'LoadText', 'MultiModalMosaic',
           'MultiModalMosaic9', 'YOLOv5MultiModalMixUp',
           'YOLOXMultiModalMixUp']
.\YOLO-World\yolo_world\datasets\utils.py
from typing import Sequence

import torch
from mmengine.dataset import COLLATE_FUNCTIONS


@COLLATE_FUNCTIONS.register_module()
def yolow_collate(data_batch: Sequence,
                  use_ms_training: bool = False) -> dict:
    """Rewrite collate_fn to get faster training speed.

    Every sample's boxes and labels are flattened into a single tensor whose
    rows are laid out as: sample index, label, then the bbox columns.

    Args:
        data_batch (Sequence): Batch of data. Each element is a dict with
            ``'inputs'`` and ``'data_samples'`` entries.
        use_ms_training (bool): Whether to use multi-scale training. If True,
            images are returned as a list instead of being stacked.

    Returns:
        dict: ``{'data_samples': {...}, 'inputs': ...}``. ``data_samples``
        always holds ``'bboxes_labels'``. It may also hold ``'masks'``,
        ``'texts'`` and ``'is_detection'`` when the samples provide them.
    """
    images = []
    per_sample_targets = []
    per_sample_masks = []

    for sample_idx, sample in enumerate(data_batch):
        data_sample = sample['data_samples']
        images.append(sample['inputs'])

        instances = data_sample.gt_instances
        boxes = instances.bboxes.tensor
        labels = instances.labels

        if 'masks' in instances:
            per_sample_masks.append(
                instances.masks.to_tensor(dtype=torch.bool,
                                          device=boxes.device))

        # Tag every row with the index of the sample it belongs to.
        owner = labels.new_full((len(labels), 1), sample_idx)
        per_sample_targets.append(
            torch.cat((owner, labels[:, None], boxes), dim=1))

    data_samples_out = {'bboxes_labels': torch.cat(per_sample_targets, 0)}
    if per_sample_masks:
        data_samples_out['masks'] = torch.cat(per_sample_masks, 0)

    collated_results = {'data_samples': data_samples_out}
    collated_results['inputs'] = (images if use_ms_training
                                  else torch.stack(images, 0))

    # Optional per-sample fields are detected from the first sample only.
    first = data_batch[0]['data_samples']
    if hasattr(first, 'texts'):
        data_samples_out['texts'] = [
            item['data_samples'].texts for item in data_batch]
    if hasattr(first, 'is_detection'):
        data_samples_out['is_detection'] = torch.tensor(
            [item['data_samples'].is_detection for item in data_batch])

    return collated_results
.\YOLO-World\yolo_world\datasets\yolov5_lvis.py
from mmdet.datasets import LVISV1Dataset

from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS


# Register the YOLOv5 LVIS v1 dataset with the DATASETS registry.
@DATASETS.register_module()
class YOLOv5LVISV1Dataset(BatchShapePolicyDataset, LVISV1Dataset):
    """Dataset for YOLOv5 LVIS Dataset.

    We only add `BatchShapePolicy` function compared with LVISV1Dataset.
    See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """
.\YOLO-World\yolo_world\datasets\yolov5_mixed_grounding.py
import os.path as osp
from typing import List, Union

from mmengine.fileio import get_local_path, join_path
from mmengine.utils import is_abs
from mmdet.datasets.coco import CocoDataset
from mmyolo.registry import DATASETS
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset


# Register YOLOv5MixedGroundingDataset with the DATASETS registry.
@DATASETS.register_module()
class YOLOv5MixedGroundingDataset(BatchShapePolicyDataset, CocoDataset):
    """Mixed grounding dataset.

    All annotations share the single class ``'object'`` (see ``METAINFO``).
    """

    METAINFO = {
        'classes': ('object',),
        'palette': [(220, 20, 60)]}

    def load_data_list(self) -> List[dict]:
        """Load annotations from an annotation file named as ``self.ann_file``

        Returns:
            List[dict]: A list of annotation.
        """  # noqa: E501
        with get_local_path(
                self.ann_file, backend_args=self.backend_args) as local_path:
            self.coco = self.COCOAPI(local_path)

        img_ids = self.coco.get_img_ids()

        data_list = []
        total_ann_ids = []
        for img_id in img_ids:
            raw_img_info = self.coco.load_imgs([img_id])[0]
            raw_img_info['img_id'] = img_id

            ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
            raw_ann_info = self.coco.load_anns(ann_ids)
            total_ann_ids.extend(ann_ids)

            parsed_data_info = self.parse_data_info({
                'raw_ann_info': raw_ann_info,
                'raw_img_info': raw_img_info
            })
            data_list.append(parsed_data_info)
        if self.ANN_ID_UNIQUE:
            assert len(set(total_ann_ids)) == len(
                total_ann_ids
            ), f"Annotation ids in '{self.ann_file}' are not unique!"

        # The COCO API object is not kept on the dataset once annotations
        # have been parsed.
        del self.coco

        return data_list

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        Images are kept when ``min(width, height) >= min_size``. No
        filtering happens in test mode or when ``filter_cfg`` is None.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode or self.filter_cfg is None:
            return self.data_list

        filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
        min_size = self.filter_cfg.get('min_size', 0)

        # NOTE(review): ``ids_with_ann`` is built from every entry of
        # ``self.data_list``, so the ``filter_empty_gt`` test below can never
        # reject an image. If images without annotations are meant to be
        # dropped, this needs a per-image annotation check. Confirm the
        # intent before changing it, because doing so alters the training set.
        ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)

        valid_data_infos = []
        for data_info in self.data_list:
            img_id = data_info['img_id']
            width = int(data_info['width'])
            height = int(data_info['height'])
            if filter_empty_gt and img_id not in ids_with_ann:
                continue
            if min(width, height) >= min_size:
                valid_data_infos.append(data_info)

        return valid_data_infos

    def _join_prefix(self):
        """Join ``self.data_root`` with ``self.data_prefix`` and
        ``self.ann_file``.

        Each value of ``self.data_prefix`` may be a str or a list/tuple of
        str. Relative entries are joined with ``self.data_root`` (when it is
        set); absolute entries are kept unchanged.

        Raises:
            TypeError: If a prefix value is not a str, tuple or list.
        """
        if self.ann_file and not is_abs(self.ann_file) and self.data_root:
            self.ann_file = join_path(self.data_root, self.ann_file)

        for data_key, prefix in self.data_prefix.items():
            if isinstance(prefix, (list, tuple)):
                abs_prefix = []
                for p in prefix:
                    if not is_abs(p) and self.data_root:
                        abs_prefix.append(join_path(self.data_root, p))
                    else:
                        abs_prefix.append(p)
                self.data_prefix[data_key] = abs_prefix
            elif isinstance(prefix, str):
                if not is_abs(prefix) and self.data_root:
                    self.data_prefix[data_key] = join_path(
                        self.data_root, prefix)
                else:
                    self.data_prefix[data_key] = prefix
            else:
                raise TypeError('prefix should be a string, tuple or list, '
                                f'but got {type(prefix)}')
.\YOLO-World\yolo_world\datasets\yolov5_obj365v1.py
from mmdet.datasets import Objects365V1Dataset

from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS


# Register YOLOv5Objects365V1Dataset with the DATASETS registry.
@DATASETS.register_module()
class YOLOv5Objects365V1Dataset(BatchShapePolicyDataset, Objects365V1Dataset):
    """Dataset for YOLOv5 Objects365 v1 Dataset.

    We only add `BatchShapePolicy` function compared with Objects365V1Dataset.
    See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """
.\YOLO-World\yolo_world\datasets\yolov5_obj365v2.py
from mmdet.datasets import Objects365V2Dataset

from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS


# Register YOLOv5Objects365V2Dataset with the DATASETS registry.
@DATASETS.register_module()
class YOLOv5Objects365V2Dataset(BatchShapePolicyDataset, Objects365V2Dataset):
    """Dataset for YOLOv5 Objects365 v2 Dataset.

    We only add `BatchShapePolicy` function compared with Objects365V2Dataset.
    See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """
.\YOLO-World\yolo_world\datasets\yolov5_v3det.py
import copy
import json
import os.path as osp
from typing import List

from mmengine.fileio import get_local_path
from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets import CocoDataset
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS

# Image files to be ignored.
# NOTE(review): this list is not referenced anywhere in this module;
# presumably it is consumed elsewhere. Confirm before removing.
v3det_ignore_list = [
    'a00013820/26_275_28143226914_ff3a247c53_c.jpg',
    'n03815615/12_1489_32968099046_be38fa580e_c.jpg',
    'n04550184/19_1480_2504784164_ffa3db8844_c.jpg',
    'a00008703/2_363_3576131784_dfac6fc6ce_c.jpg',
    'n02814533/28_2216_30224383848_a90697f1b3_c.jpg',
    'n12026476/29_186_15091304754_5c219872f7_c.jpg',
    'n01956764/12_2004_50133201066_72e0d9fea5_c.jpg',
    'n03785016/14_2642_518053131_d07abcb5da_c.jpg',
    'a00011156/33_250_4548479728_9ce5246596_c.jpg',
    'a00009461/19_152_2792869324_db95bebc84_c.jpg',
]


# Register V3DetDataset with the DATASETS registry.
@DATASETS.register_module()
class V3DetDataset(CocoDataset):
    """V3Det dataset for detection, read through the COCO API.

    NOTE(review): ``METAINFO['classes']`` is the placeholder string
    ``'classes'`` rather than a tuple of class names. Presumably the real
    class list is supplied via the dataset ``metainfo`` config; confirm.
    """

    METAINFO = {'classes': 'classes', 'palette': None}
    COCOAPI = COCO
    # ann_id is unique in coco dataset.
    ANN_ID_UNIQUE = True


# Register YOLOv5V3DetDataset (BatchShapePolicyDataset + V3DetDataset).
@DATASETS.register_module()
class YOLOv5V3DetDataset(BatchShapePolicyDataset, V3DetDataset):
    """Dataset for YOLOv5 V3Det Dataset.

    We only add `BatchShapePolicy` function compared with V3DetDataset.
    See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """
.\YOLO-World\yolo_world\datasets\__init__.py
"""Datasets, transforms and the collate function used by YOLO-World."""
from .mm_dataset import MultiModalDataset, MultiModalMixedDataset
from .yolov5_obj365v1 import YOLOv5Objects365V1Dataset
from .yolov5_obj365v2 import YOLOv5Objects365V2Dataset
from .yolov5_mixed_grounding import YOLOv5MixedGroundingDataset
from .utils import yolow_collate
from .transformers import *  # NOQA
from .yolov5_v3det import YOLOv5V3DetDataset
from .yolov5_lvis import YOLOv5LVISV1Dataset

# Public API of the datasets package.
__all__ = [
    'MultiModalDataset',
    'YOLOv5Objects365V1Dataset',
    'YOLOv5Objects365V2Dataset',
    'YOLOv5MixedGroundingDataset',
    'YOLOv5V3DetDataset',
    'yolow_collate',
    'YOLOv5LVISV1Dataset',
    'MultiModalMixedDataset',
]
.\YOLO-World\yolo_world\engine\optimizers\yolow_v5_optim_constructor.py
# Copyright (c) Tencent Inc. All rights reserved.
import logging
from typing import List, Optional, Union

import torch
import torch.nn as nn
from torch.nn import GroupNorm, LayerNorm
from mmengine.dist import get_world_size
from mmengine.logging import print_log
from mmengine.optim import OptimWrapper, DefaultOptimWrapperConstructor
from mmengine.utils.dl_utils import mmcv_full_available
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm

from mmyolo.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                             OPTIMIZERS)


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class YOLOWv5OptimizerConstructor(DefaultOptimWrapperConstructor):
    """YOLO World v5 constructor for optimizers.

    Extends ``DefaultOptimWrapperConstructor`` with YOLOv5-style weight decay
    scaling. When the optimizer config contains ``batch_size_per_gpu``, the
    weight decay is multiplied by
    ``total_batch_size * accumulate / base_total_batch_size``.

    Args:
        optim_wrapper_cfg (dict): Optimizer wrapper config, forwarded to the
            parent constructor.
        paramwise_cfg (dict, optional): Parameter-wise options. The extra key
            ``base_total_batch_size`` (default 64) is removed from it and used
            only for weight decay scaling.
    """

    def __init__(self,
                 optim_wrapper_cfg: dict,
                 paramwise_cfg: Optional[dict] = None) -> None:
        super().__init__(optim_wrapper_cfg, paramwise_cfg)
        # Popping (rather than reading) removes the key. A paramwise_cfg that
        # held only this key is therefore empty, and ``__call__`` then takes
        # the global-settings branch.
        self.base_total_batch_size = self.paramwise_cfg.pop(
            'base_total_batch_size', 64)

    def __call__(self, model: nn.Module) -> OptimWrapper:
        """Build the optimizer wrapper for ``model``.

        Args:
            model (nn.Module): Model to optimize. If it exposes a ``.module``
                attribute (e.g. a wrapped model), that inner module is used.

        Returns:
            OptimWrapper: Wrapper built from ``optim_wrapper_cfg``. Its type
            defaults to ``'OptimWrapper'``, and it wraps the built optimizer.
        """
        if hasattr(model, 'module'):
            model = model.module

        wrapper_cfg = self.optim_wrapper_cfg.copy()
        wrapper_cfg.setdefault('type', 'OptimWrapper')
        optimizer_cfg = self.optimizer_cfg.copy()

        # Follow the original YOLOv5 implementation.
        if 'batch_size_per_gpu' in optimizer_cfg:
            per_gpu = optimizer_cfg.pop('batch_size_per_gpu')
            total_batch_size = get_world_size() * per_gpu
            accumulate = max(
                round(self.base_total_batch_size / total_batch_size), 1)
            scale_factor = (total_batch_size * accumulate /
                            self.base_total_batch_size)
            if scale_factor != 1:
                weight_decay = (optimizer_cfg.get('weight_decay', 0) *
                                scale_factor)
                optimizer_cfg['weight_decay'] = weight_decay
                print_log(f'Scaled weight_decay to {weight_decay}', 'current')

        if self.paramwise_cfg:
            # Set lr / weight decay per parameter, recursively.
            params: List = []
            self.add_params(params, model)
            optimizer_cfg['params'] = params
        else:
            # No paramwise options: every parameter uses the global settings.
            optimizer_cfg['params'] = model.parameters()
        optimizer = OPTIMIZERS.build(optimizer_cfg)

        return OPTIM_WRAPPERS.build(
            wrapper_cfg, default_args=dict(optimizer=optimizer))
yolo-world 源码解析(五)(2) https://developer.aliyun.com/article/1483890