yolo-world 源码解析(五)(1)

本文涉及的产品
云解析 DNS,旗舰版 1个月
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: yolo-world 源码解析(五)

.\YOLO-World\yolo_world\datasets\transformers\mm_transforms.py

# 导入所需的库
import json
import random
from typing import Tuple
import numpy as np
from mmyolo.registry import TRANSFORMS
# 注册 RandomLoadText 类为 TRANSFORMS 模块
@TRANSFORMS.register_module()
class RandomLoadText:
    def __init__(self,
                 text_path: str = None,
                 prompt_format: str = '{}',
                 num_neg_samples: Tuple[int, int] = (80, 80),
                 max_num_samples: int = 80,
                 padding_to_max: bool = False,
                 padding_value: str = '') -> None:
        # 初始化 RandomLoadText 类的属性
        self.prompt_format = prompt_format
        self.num_neg_samples = num_neg_samples
        self.max_num_samples = max_num_samples
        self.padding_to_max = padding_to_max
        self.padding_value = padding_value
        # 如果指定了 text_path,则读取对应文件内容
        if text_path is not None:
            with open(text_path, 'r') as f:
                self.class_texts = json.load(f)
# 注册 LoadText 类为 TRANSFORMS 模块
@TRANSFORMS.register_module()
class LoadText:
    def __init__(self,
                 text_path: str = None,
                 prompt_format: str = '{}',
                 multi_prompt_flag: str = '/') -> None:
        # 初始化 LoadText 类的属性
        self.prompt_format = prompt_format
        self.multi_prompt_flag = multi_prompt_flag
        # 如果指定了 text_path,则读取对应文件内容
        if text_path is not None:
            with open(text_path, 'r') as f:
                self.class_texts = json.load(f)
    # 定义 __call__ 方法,用于处理结果字典
    def __call__(self, results: dict) -> dict:
        # 检查结果字典中是否包含 'texts' 键或者类属性中是否包含 'class_texts'
        assert 'texts' in results or hasattr(self, 'class_texts'), (
            'No texts found in results.')
        # 获取类属性中的 'class_texts' 或者结果字典中的 'texts'
        class_texts = results.get(
            'texts',
            getattr(self, 'class_texts', None))
        texts = []
        # 遍历类别文本列表,处理每个类别文本
        for idx, cls_caps in enumerate(class_texts):
            assert len(cls_caps) > 0
            sel_cls_cap = cls_caps[0]
            sel_cls_cap = self.prompt_format.format(sel_cls_cap)
            texts.append(sel_cls_cap)
        # 将处理后的文本列表存入结果字典中的 'texts' 键
        results['texts'] = texts
        return results

.\YOLO-World\yolo_world\datasets\transformers\__init__.py

# 导入腾讯公司的所有权声明
# 从当前目录下的 mm_transforms 模块中导入 RandomLoadText 和 LoadText 类
# 从当前目录下的 mm_mix_img_transforms 模块中导入 MultiModalMosaic、MultiModalMosaic9、YOLOv5MultiModalMixUp、YOLOXMultiModalMixUp 类
# 定义 __all__ 列表,包含需要导出的类名
__all__ = ['RandomLoadText', 'LoadText', 'MultiModalMosaic',
           'MultiModalMosaic9', 'YOLOv5MultiModalMixUp',
           'YOLOXMultiModalMixUp']

.\YOLO-World\yolo_world\datasets\utils.py

# 导入必要的库和模块
from typing import Sequence
import torch
from mmengine.dataset import COLLATE_FUNCTIONS
# 注册自定义的数据集拼接函数
@COLLATE_FUNCTIONS.register_module()
def yolow_collate(data_batch: Sequence,
                  use_ms_training: bool = False) -> dict:
    """Rewrite collate_fn to get faster training speed.
    Args:
       data_batch (Sequence): Batch of data.
       use_ms_training (bool): Whether to use multi-scale training.
    """
    # 初始化空列表用于存储数据
    batch_imgs = []
    batch_bboxes_labels = []
    batch_masks = []
    
    # 遍历数据批次
    for i in range(len(data_batch)):
        datasamples = data_batch[i]['data_samples']
        inputs = data_batch[i]['inputs']
        batch_imgs.append(inputs)
        # 获取 ground truth 边界框和标签
        gt_bboxes = datasamples.gt_instances.bboxes.tensor
        gt_labels = datasamples.gt_instances.labels
        
        # 如果数据中包含 masks,则转换为张量并添加到 batch_masks 列表中
        if 'masks' in datasamples.gt_instances:
            masks = datasamples.gt_instances.masks.to_tensor(
                dtype=torch.bool, device=gt_bboxes.device)
            batch_masks.append(masks)
        
        # 创建 batch_idx 用于标识数据批次,拼接边界框和标签
        batch_idx = gt_labels.new_full((len(gt_labels), 1), i)
        bboxes_labels = torch.cat((batch_idx, gt_labels[:, None], gt_bboxes),
                                  dim=1)
        batch_bboxes_labels.append(bboxes_labels)
    # 构建拼接后的结果字典
    collated_results = {
        'data_samples': {
            'bboxes_labels': torch.cat(batch_bboxes_labels, 0)
        }
    }
    
    # 如果存在 masks 数据,则添加到结果字典中
    if len(batch_masks) > 0:
        collated_results['data_samples']['masks'] = torch.cat(batch_masks, 0)
    # 根据是否使用多尺度训练,将输入数据添加到结果字典中
    if use_ms_training:
        collated_results['inputs'] = batch_imgs
    else:
        collated_results['inputs'] = torch.stack(batch_imgs, 0)
    # 如果数据中包含文本信息,则添加到结果字典中
    if hasattr(data_batch[0]['data_samples'], 'texts'):
        batch_texts = [meta['data_samples'].texts for meta in data_batch]
        collated_results['data_samples']['texts'] = batch_texts
    # 检查第一个数据批次中的'data_samples'是否具有'is_detection'属性
    if hasattr(data_batch[0]['data_samples'], 'is_detection'):
        # 如果具有'data_samples'中的'is_detection'属性,则提取每个数据批次中'data_samples'的'is_detection'值
        batch_detection = [meta['data_samples'].is_detection
                           for meta in data_batch]
        # 将提取的'data_samples'中的'is_detection'值转换为torch张量,并存储在collated_results字典中
        collated_results['data_samples']['is_detection'] = torch.tensor(
            batch_detection)
    # 返回整理后的结果字典
    return collated_results

.\YOLO-World\yolo_world\datasets\yolov5_lvis.py

# 导入需要的模块
from mmdet.datasets import LVISV1Dataset
# 导入自定义的数据集类
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS
# 注册YOLOv5 LVIS数据集类,继承自BatchShapePolicyDataset和LVISV1Dataset
@DATASETS.register_module()
class YOLOv5LVISV1Dataset(BatchShapePolicyDataset, LVISV1Dataset):
    """Dataset for YOLOv5 LVIS Dataset.
    We only add `BatchShapePolicy` function compared with Objects365V1Dataset.
    See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """
    # 空的类定义,没有额外的方法或属性
    pass

.\YOLO-World\yolo_world\datasets\yolov5_mixed_grounding.py

# 导入必要的模块
import os.path as osp
from typing import List, Union
# 导入自定义模块
from mmengine.fileio import get_local_path, join_path
from mmengine.utils import is_abs
from mmdet.datasets.coco import CocoDataset
from mmyolo.registry import DATASETS
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
# 注册YOLOv5MixedGroundingDataset类为DATASETS
@DATASETS.register_module()
class YOLOv5MixedGroundingDataset(BatchShapePolicyDataset, CocoDataset):
    """Mixed grounding dataset."""
    # 定义元信息
    METAINFO = {
        'classes': ('object',),
        'palette': [(220, 20, 60)]}
    # 加载数据列表
    def load_data_list(self) -> List[dict]:
        """Load annotations from an annotation file named as ``self.ann_file``
        Returns:
            List[dict]: A list of annotation.
        """  # noqa: E501
        # 使用get_local_path函数获取本地路径
        with get_local_path(
                self.ann_file, backend_args=self.backend_args) as local_path:
            # 使用COCOAPI加载本地路径的数据
            self.coco = self.COCOAPI(local_path)
        # 获取图像ID列表
        img_ids = self.coco.get_img_ids()
        data_list = []
        total_ann_ids = []
        for img_id in img_ids:
            # 加载原始图像信息
            raw_img_info = self.coco.load_imgs([img_id])[0]
            raw_img_info['img_id'] = img_id
            # 获取图像对应的注释ID列表
            ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
            raw_ann_info = self.coco.load_anns(ann_ids)
            total_ann_ids.extend(ann_ids)
            # 解析数据信息
            parsed_data_info = self.parse_data_info({
                'raw_ann_info':
                raw_ann_info,
                'raw_img_info':
                raw_img_info
            })
            data_list.append(parsed_data_info)
        # 检查注释ID是否唯一
        if self.ANN_ID_UNIQUE:
            assert len(set(total_ann_ids)) == len(
                total_ann_ids
            ), f"Annotation ids in '{self.ann_file}' are not unique!"
        # 删除self.coco对象
        del self.coco
        # 返回数据列表
        return data_list
    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.
        Returns:
            List[dict]: Filtered results.
        """
        # 如果处于测试模式,则直接返回原始数据列表
        if self.test_mode:
            return self.data_list
        # 如果没有设置过滤配置,则直接返回原始数据列表
        if self.filter_cfg is None:
            return self.data_list
        # 获取过滤空标注和最小尺寸的配置参数
        filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
        min_size = self.filter_cfg.get('min_size', 0)
        # 获取包含标注的图片的 ID 集合
        ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)
        valid_data_infos = []
        # 遍历数据列表,筛选符合条件的数据信息
        for i, data_info in enumerate(self.data_list):
            img_id = data_info['img_id']
            width = int(data_info['width'])
            height = int(data_info['height'])
            # 如果设置了过滤空标注并且当前图片没有标注,则跳过
            if filter_empty_gt and img_id not in ids_with_ann:
                continue
            # 如果图片宽高中的最小值大于等于最小尺寸,则将该数据信息添加到有效数据列表中
            if min(width, height) >= min_size:
                valid_data_infos.append(data_info)
        # 返回筛选后的有效数据信息列表
        return valid_data_infos
    # 将 self.data_root 与 self.data_prefix 和 self.ann_file 连接起来
    def _join_prefix(self):
        """Join ``self.data_root`` with ``self.data_prefix`` and
        ``self.ann_file``.
        """
        # 如果 self.ann_file 不是绝对路径且 self.data_root 存在,则自动将注释文件路径与 self.root 连接起来
        if self.ann_file and not is_abs(self.ann_file) and self.data_root:
            self.ann_file = join_path(self.data_root, self.ann_file)
        # 如果 self.data_prefix 中的路径值不是绝对路径,则自动将数据目录与 self.root 连接起来
        for data_key, prefix in self.data_prefix.items():
            if isinstance(prefix, (list, tuple)):
                abs_prefix = []
                for p in prefix:
                    if not is_abs(p) and self.data_root:
                        abs_prefix.append(join_path(self.data_root, p))
                    else:
                        abs_prefix.append(p)
                self.data_prefix[data_key] = abs_prefix
            elif isinstance(prefix, str):
                if not is_abs(prefix) and self.data_root:
                    self.data_prefix[data_key] = join_path(
                        self.data_root, prefix)
                else:
                    self.data_prefix[data_key] = prefix
            else:
                raise TypeError('prefix should be a string, tuple or list,'
                                f'but got {type(prefix)}')

.\YOLO-World\yolo_world\datasets\yolov5_obj365v1.py

# 导入需要的模块
from mmdet.datasets import Objects365V1Dataset
# 导入自定义的数据集类
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS
# 注册YOLOv5Objects365V1Dataset类到DATASETS模块
@DATASETS.register_module()
class YOLOv5Objects365V1Dataset(BatchShapePolicyDataset, Objects365V1Dataset):
    """Dataset for YOLOv5 VOC Dataset.
    We only add `BatchShapePolicy` function compared with Objects365V1Dataset.
    See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """
    pass

.\YOLO-World\yolo_world\datasets\yolov5_obj365v2.py

# 导入 Objects365V2Dataset 类
from mmdet.datasets import Objects365V2Dataset
# 导入 BatchShapePolicyDataset 类
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
# 导入 DATASETS 注册表
from mmyolo.registry import DATASETS
# 注册 YOLOv5Objects365V2Dataset 类到 DATASETS 注册表
@DATASETS.register_module()
class YOLOv5Objects365V2Dataset(BatchShapePolicyDataset, Objects365V2Dataset):
    """Dataset for YOLOv5 VOC Dataset.
    We only add `BatchShapePolicy` function compared with Objects365V1Dataset.
    See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """
    # 空的类定义,继承自 BatchShapePolicyDataset 和 Objects365V2Dataset
    pass

.\YOLO-World\yolo_world\datasets\yolov5_v3det.py

# 导入所需的模块和函数
import copy
import json
import os.path as osp
from typing import List
from mmengine.fileio import get_local_path
from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets import CocoDataset
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS
# 定义需要忽略的文件列表
v3det_ignore_list = [
    'a00013820/26_275_28143226914_ff3a247c53_c.jpg',
    'n03815615/12_1489_32968099046_be38fa580e_c.jpg',
    'n04550184/19_1480_2504784164_ffa3db8844_c.jpg',
    'a00008703/2_363_3576131784_dfac6fc6ce_c.jpg',
    'n02814533/28_2216_30224383848_a90697f1b3_c.jpg',
    'n12026476/29_186_15091304754_5c219872f7_c.jpg',
    'n01956764/12_2004_50133201066_72e0d9fea5_c.jpg',
    'n03785016/14_2642_518053131_d07abcb5da_c.jpg',
    'a00011156/33_250_4548479728_9ce5246596_c.jpg',
    'a00009461/19_152_2792869324_db95bebc84_c.jpg',
]
# 注册 V3DetDataset 类
@DATASETS.register_module()
class V3DetDataset(CocoDataset):
    """Objects365 v1 dataset for detection."""
    METAINFO = {'classes': 'classes', 'palette': None}
    COCOAPI = COCO
    # ann_id is unique in coco dataset.
    ANN_ID_UNIQUE = True
# 注册 YOLOv5V3DetDataset 类,继承自 BatchShapePolicyDataset 和 V3DetDataset
@DATASETS.register_module()
class YOLOv5V3DetDataset(BatchShapePolicyDataset, V3DetDataset):
    """Dataset for YOLOv5 VOC Dataset.
    We only add `BatchShapePolicy` function compared with Objects365V1Dataset.
    See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """
    pass

.\YOLO-World\yolo_world\datasets\__init__.py

# 导入所需的模块和类
from .mm_dataset import (
    MultiModalDataset, MultiModalMixedDataset)
from .yolov5_obj365v1 import YOLOv5Objects365V1Dataset
from .yolov5_obj365v2 import YOLOv5Objects365V2Dataset
from .yolov5_mixed_grounding import YOLOv5MixedGroundingDataset
from .utils import yolow_collate
from .transformers import *  # NOQA
from .yolov5_v3det import YOLOv5V3DetDataset
from .yolov5_lvis import YOLOv5LVISV1Dataset
# 定义导出的模块和类列表
__all__ = [
    'MultiModalDataset', 'YOLOv5Objects365V1Dataset',
    'YOLOv5Objects365V2Dataset', 'YOLOv5MixedGroundingDataset',
    'YOLOv5V3DetDataset', 'yolow_collate',
    'YOLOv5LVISV1Dataset', 'MultiModalMixedDataset',
]

.\YOLO-World\yolo_world\engine\optimizers\yolow_v5_optim_constructor.py

# 版权声明,版权归腾讯公司所有
import logging
from typing import List, Optional, Union
import torch
import torch.nn as nn
from torch.nn import GroupNorm, LayerNorm
from mmengine.dist import get_world_size
from mmengine.logging import print_log
from mmengine.optim import OptimWrapper, DefaultOptimWrapperConstructor
from mmengine.utils.dl_utils import mmcv_full_available
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm
from mmyolo.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                             OPTIMIZERS)
# 注册优化器包装器构造函数
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class YOLOWv5OptimizerConstructor(DefaultOptimWrapperConstructor):
    """YOLO World v5 constructor for optimizers."""
    # 初始化函数,接受优化器包装器配置和参数配置
    def __init__(self,
                 optim_wrapper_cfg: dict,
                 paramwise_cfg: Optional[dict] = None) -> None:
        # 调用父类的初始化函数
        super().__init__(optim_wrapper_cfg, paramwise_cfg)
        # 从参数配置中弹出'base_total_batch_size',默认值为64
        self.base_total_batch_size = self.paramwise_cfg.pop(
            'base_total_batch_size', 64)
    # 定义一个方法,用于为模型创建优化器包装器
    def __call__(self, model: nn.Module) -> OptimWrapper:
        # 如果模型有'module'属性,则将'module'属性赋值给model
        if hasattr(model, 'module'):
            model = model.module
        # 复制优化器包装器配置
        optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
        # 设置默认的优化器包装器类型为'OptimWrapper'
        optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
        # 复制优化器配置
        optimizer_cfg = self.optimizer_cfg.copy()
        # 遵循原始的yolov5实现
        if 'batch_size_per_gpu' in optimizer_cfg:
            # 弹出'batch_size_per_gpu'键值对,并赋值给batch_size_per_gpu
            batch_size_per_gpu = optimizer_cfg.pop('batch_size_per_gpu')
            # 计算总批量大小
            total_batch_size = get_world_size() * batch_size_per_gpu
            # 计算累积步数
            accumulate = max(
                round(self.base_total_batch_size / total_batch_size), 1)
            # 计算缩放因子
            scale_factor = total_batch_size * \
                accumulate / self.base_total_batch_size
            # 如果缩放因子不等于1
            if scale_factor != 1:
                # 获取优化器配置中的权重衰减值
                weight_decay = optimizer_cfg.get('weight_decay', 0)
                # 根据缩放因子调整权重衰减值
                weight_decay *= scale_factor
                optimizer_cfg['weight_decay'] = weight_decay
                # 打印调整后的权重衰减值
                print_log(f'Scaled weight_decay to {weight_decay}', 'current')
        # 如果没有指定paramwise选项,则使用全局设置
        if not self.paramwise_cfg:
            # 将模型的参数设置为优化器配置的参数
            optimizer_cfg['params'] = model.parameters()
            # 构建优化器
            optimizer = OPTIMIZERS.build(optimizer_cfg)
        else:
            # 递归设置参数的学习率和权重衰减
            params: List = []
            self.add_params(params, model)
            optimizer_cfg['params'] = params
            optimizer = OPTIMIZERS.build(optimizer_cfg)
        # 构建优化器包装器
        optim_wrapper = OPTIM_WRAPPERS.build(
            optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
        # 返回优化器包装器
        return optim_wrapper

yolo-world 源码解析(五)(2)https://developer.aliyun.com/article/1483890

相关文章
|
2月前
|
监控 网络协议 Java
Tomcat源码解析】整体架构组成及核心组件
Tomcat,原名Catalina,是一款优雅轻盈的Web服务器,自4.x版本起扩展了JSP、EL等功能,超越了单纯的Servlet容器范畴。Servlet是Sun公司为Java编程Web应用制定的规范,Tomcat作为Servlet容器,负责构建Request与Response对象,并执行业务逻辑。
Tomcat源码解析】整体架构组成及核心组件
|
21天前
|
存储 缓存 Java
什么是线程池?从底层源码入手,深度解析线程池的工作原理
本文从底层源码入手,深度解析ThreadPoolExecutor底层源码,包括其核心字段、内部类和重要方法,另外对Executors工具类下的四种自带线程池源码进行解释。 阅读本文后,可以对线程池的工作原理、七大参数、生命周期、拒绝策略等内容拥有更深入的认识。
什么是线程池?从底层源码入手,深度解析线程池的工作原理
|
25天前
|
开发工具
Flutter-AnimatedWidget组件源码解析
Flutter-AnimatedWidget组件源码解析
|
21天前
|
设计模式 Java 关系型数据库
【Java笔记+踩坑汇总】Java基础+JavaWeb+SSM+SpringBoot+SpringCloud+瑞吉外卖/谷粒商城/学成在线+设计模式+面试题汇总+性能调优/架构设计+源码解析
本文是“Java学习路线”专栏的导航文章,目标是为Java初学者和初中高级工程师提供一套完整的Java学习路线。
178 37
|
13天前
|
编解码 开发工具 UED
QT Widgets模块源码解析与实践
【9月更文挑战第20天】Qt Widgets 模块是 Qt 开发中至关重要的部分,提供了丰富的 GUI 组件,如按钮、文本框等,并支持布局管理、事件处理和窗口管理。这些组件基于信号与槽机制,实现灵活交互。通过对源码的解析及实践应用,可深入了解其类结构、布局管理和事件处理机制,掌握创建复杂 UI 界面的方法,提升开发效率和用户体验。
64 12
|
2月前
|
测试技术 Python
python自动化测试中装饰器@ddt与@data源码深入解析
综上所述,使用 `@ddt`和 `@data`可以大大简化写作测试用例的过程,让我们能专注于测试逻辑的本身,而无需编写重复的测试方法。通过讲解了 `@ddt`和 `@data`源码的关键部分,我们可以更深入地理解其背后的工作原理。
30 1
|
2月前
|
算法 安全 Java
深入解析Java多线程:源码级别的分析与实践
深入解析Java多线程:源码级别的分析与实践
|
2月前
|
存储 NoSQL Redis
redis 6源码解析之 object
redis 6源码解析之 object
58 6
|
2月前
|
开发者 Python
深入解析Python `httpx`源码,探索现代HTTP客户端的秘密!
深入解析Python `httpx`源码,探索现代HTTP客户端的秘密!
72 1
|
2月前
|
开发者 Python
深入解析Python `requests`库源码,揭开HTTP请求的神秘面纱!
深入解析Python `requests`库源码,揭开HTTP请求的神秘面纱!
132 1

热门文章

最新文章

推荐镜像

更多
下一篇
无影云桌面