TorchVision Transforms API 大升级,支持目标检测、实例/语义分割及视频类任务

简介: TorchVision Transforms API 大升级,支持目标检测、实例/语义分割及视频类任务


内容导读:TorchVision Transforms API 扩展升级,现已支持目标检测、实例及语义分割以及视频类任务。新 API 尚处于测试阶段,开发者可以试用体验。

本文首发自微信公众号:PyTorch 开发者社区

image.png

TorchVision 现已针对 Transforms API 进行了扩展, 具体如下:

  • 除用于图像分类外,现在还可以用其进行目标检测、实例及语义分割以及视频分类等任务;
  • 支持从 TorchVision 直接导入 SoTA 数据增强,如 MixUp、 CutMix、Large Scale Jitter 以及 SimpleCopyPaste。
  • 支持使用全新的 functional transforms 转换视频、Bounding box 以及分割掩码 (Segmentation Mask)。

Transforms 当前的局限性

稳定版 TorchVision Transforms API,也也就是我们常说的 Transforms V1,只支持单个图像,因此,只适用于分类任务:

from torchvision import transforms
trans = transforms.Compose([
   transforms.ColorJitter(contrast=0.5),
   transforms.RandomRotation(30),
   transforms.CenterCrop(480),
])
imgs = trans(imgs)

上述方法不支持需要使用 Label 的目标检测、分割或分类 Transforms, 如 MixUp 及 cutMix。这使分类以外的计算机视觉任务都不能用 Transforms API 执行必要的扩展。同时,这也加大了用 TorchVision 原语训练高精度模型的难度。

为了克服这个局限性,TorchVision 在其 reference script 中提供了自定义实现, 用于演示所有任务中的增强是如何执行的。

尽管这种做法使得开发者能够训练出高精度的分类、目标检测及分割模型,但做法比较粗糙,TorchVision 二进制文件中还是不能导入 Transforms。

全新的 Transforms API

Transforms V2 API 支持视频、bounding box、label 以及分割掩码, 这意味着它为许多计算机视觉任务提供了本地支持。新的解决方案是一种更为直接的替代方案:

from torchvision.prototype import transforms
# Exactly the same interface as V1:
trans = transforms.Compose([
    transforms.ColorJitter(contrast=0.5),
    transforms.RandomRotation(30),
    transforms.CenterCrop(480),
])
imgs, bboxes, labels = trans(imgs, bboxes, labels)

全新的 Transform Class 无需强制执行特定的顺序或结构,就可以接收任意数量的输入:

# Already supported:
trans(imgs)  # Image Classification
trans(videos)  # Video Tasks
trans(imgs_or_videos, labels)  # MixUp/CutMix-style Transforms
trans(imgs, bboxes, labels)  # Object Detection
trans(imgs, bboxes, masks, labels)  # Instance Segmentation
trans(imgs, masks)  # Semantic Segmentation
trans({"image": imgs, "box": bboxes, "tag": labels})  # Arbitrary Structure
# Future support:
trans(imgs, bboxes, labels, keypoints)  # Keypoint Detection
trans(stereo_images, disparities, masks)  # Depth Perception
trans(image1, image2, optical_flows, masks)  # Optical Flow

functional API 已经更新,支持所有输入必要的 signal processing kernel,如 resizing, cropping, affine transforms, padding 等:

from torchvision.prototype.transforms import functional as F
# High-level dispatcher, accepts any supported input type, fully BC
F.resize(inpt, resize=[224, 224])
# Image tensor kernel
F.resize_image_tensor(img_tensor, resize=[224, 224], antialias=True)
# PIL image kernel
F.resize_image_pil(img_pil, resize=[224, 224], interpolation=BILINEAR)
# Video kernel
F.resize_video(video, resize=[224, 224], antialias=True)
# Mask kernel
F.resize_mask(mask, resize=[224, 224])
# Bounding box kernel
F.resize_bounding_box(bbox, resize=[224, 224], spatial_size=[256, 256])

API 使用 Tensor subclassing 来包装输入,附加有用的元数据,并 dispatch 到正确的内核。 利用 TorchData Data Pipe 的 Datasets V2 相关工作完成后,就不再需要手动包装输入了。目前,用户可以通过以下方式手动包装输入:

from torchvision.prototype import features
imgs = features.Image(images, color_space=ColorSpace.RGB)
vids = features.Video(videos, color_space=ColorSpace.RGB)
masks = features.Mask(target["masks"])
bboxes = features.BoundingBox(target["boxes"], format=BoundingBoxFormat.XYXY, spatial_size=imgs.spatial_size)
labels = features.Label(target["labels"], categories=["dog", "cat"])

除新 API 之外,PyTorch 官方还为 SoTA 研究中用到的一些数据增强提供了重要实现,如 MixUp、 CutMix、Large Scale Jitter、 SimpleCopyPaste、AutoAugmentation 方法以及一些新的 Geometric、Colour 和 Type Conversion transforms。

该 API 继续支持 single image 或 batched input image 的 PIL 和 Tensor 后端,并在 functional API 上保留了 JIT-scriptability。这使得图像映射得以从 uint8 延迟到 float, 带来了性能的进一步提升。

它目前可以在 TorchVision 的原型区域 (prototype area) 中使用,并且支持从 nightly build 版本中导入。经验证,新 API 与先前实现的准确性一致。

当前的局限性

functional API (kernel) 仍然保持 JIT-scriptable 及 fully-BC,Transform Class 提供了相同的接口,却无法使用脚本。

这是因为 Transform Class 使用的是张量子类 (Tensor Subclassing),且接收任意数量的输入,这是 JIT 所不支持的。该局限将在后续版本中不断优化。

一个端到端示

以下是一个新 API 示例,它可以同时使用 PIL 图像和张量。

测试图片:

image.png

代码示例:

import PIL
from torchvision import io, utils
from torchvision.prototype import features, transforms as T
from torchvision.prototype.transforms import functional as F
# Defining and wrapping input to appropriate Tensor Subclasses
path = "COCO_val2014_000000418825.jpg"
img = features.Image(io.read_image(path), color_space=features.ColorSpace.RGB)
# img = PIL.Image.open(path)
bboxes = features.BoundingBox(
    [[2, 0, 206, 253], [396, 92, 479, 241], [328, 253, 417, 332],
     [148, 68, 256, 182], [93, 158, 170, 260], [432, 0, 438, 26],
     [422, 0, 480, 25], [419, 39, 424, 52], [448, 37, 456, 62],
     [435, 43, 437, 50], [461, 36, 469, 63], [461, 75, 469, 94],
     [469, 36, 480, 64], [440, 37, 446, 56], [398, 233, 480, 304],
     [452, 39, 463, 63], [424, 38, 429, 50]],
    format=features.BoundingBoxFormat.XYXY,
    spatial_size=F.get_spatial_size(img),
)
labels = features.Label([59, 58, 50, 64, 76, 74, 74, 74, 74, 74, 74, 74, 74, 74, 50, 74, 74])
# Defining and applying Transforms V2
trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480),
    ]
)
img, bboxes, labels = trans(img, bboxes, labels)
# Visualizing results
viz = utils.draw_bounding_boxes(F.to_image_tensor(img), boxes=bboxes)
F.to_pil_image(viz).show()

—— 完 ——

相关文章
|
8月前
|
人工智能 安全 API
Higress MCP Server 安全再升级:API 认证为 AI 连接保驾护航
Higress MCP Server 新增了 API 认证功能,为 AI 连接提供安全保障。主要更新包括:1) 客户端到 MCP Server 的认证,支持 Key Auth、JWT Auth 和 OAuth2;2) MCP Server 到后端 API 的认证,增强第二阶段的安全性。新增功能如可重用认证方案、工具特定后端认证、透明凭证透传及灵活凭证管理,确保安全集成更多后端服务。通过 openapi-to-mcp 工具简化配置,减少手动工作量。企业版提供更高可用性保障,详情参见文档链接。
962 42
|
6月前
|
人工智能 供应链 安全
AI驱动攻防升级,API安全走到关键档口
在AI与数字化转型加速背景下,API已成为企业连接内外业务的核心枢纽,但其面临的安全威胁也日益严峻。瑞数信息发布的《API安全趋势报告》指出,2024年API攻击流量同比增长162%,占所有网络攻击的78%。攻击呈现规模化、智能化、链式扩散等新特征,传统防护手段已难应对。报告建议企业构建覆盖API全生命周期的安全体系,强化资产梳理、访问控制、LLM防护、供应链管控等七大能力,提升动态防御水平,保障AI时代下的业务安全与稳定。
238 0
|
5月前
|
JSON 监控 API
抖音视频列表API秘籍!轻松获取视频列表数据
抖音视频列表API是抖音开放平台提供的核心接口,支持按关键词、分类、排序方式筛选视频,适用于内容推荐、趋势分析等场景。接口返回含视频ID、标题、播放量等50+字段,支持分页获取,通过HTTP GET请求调用,返回JSON格式数据,便于开发者快速集成与处理。需注册平台账号获取访问权限。
1273 56
|
5月前
|
JSON 监控 API
抖音视频详情API秘籍!轻松获取视频详情数据
抖音视频详情API是抖音开放平台的核心接口,通过视频ID可获取包括标题、播放量、点赞数、评论等50多个字段,适用于内容分析、竞品监控和广告评估等场景。接口支持HTTP GET请求,返回JSON格式数据,便于解析处理。文中还提供了使用Python调用该接口的示例代码,包含请求发送、认证、响应处理等功能,帮助开发者快速获取视频数据。
1244 5
|
5月前
|
API 开发工具 开发者
客流类API实测:门店到访客群画像数据
本文介绍了一个实用的API——“门店到访客群画像分布”,适用于线下实体门店进行客群画像分析。该API支持多种画像维度,如性别、年龄、职业、消费偏好等,帮助商家深入了解顾客特征,提升运营效率。文章详细说明了API的参数配置、响应数据、接入流程,并附有Python调用示例,便于开发者快速集成。适合零售、餐饮等行业从业者使用。
客流类API实测:门店到访客群画像数据
|
8月前
|
数据采集 安全 大数据
Dataphin 5.1:API数据源及管道组件升级,适配多样化认证的API
为提升API数据交互安全性,Dataphin 5.1推出两种新认证方式:基于OAuth 2.0的动态授权与请求签名认证。前者通过短期Access Token确保安全,后者对关键参数加密签名保障数据完整性。功能支持API数据源OAuth 2.0认证和自定义签名配置,未来还将拓展更灵活的认证方式以满足多样化需求。
240 14
|
9月前
|
存储 安全 API
秘密任务 1.0:为什么 DTO 是 API 设计效率和安全性的秘密武器?
在软件开发中,确保API安全与高效至关重要。本文通过“间谍机构”场景,介绍数据传输对象(DTO)的作用。DTO是一种设计模式,用于格式化数据并隐藏敏感信息,仅传送必要内容。例如,在特工数据中,DTO可过滤掉密码和任务详情,仅返回代号和权限等级。使用DTO能简化前后端通信、提升性能和安全性。 文中示例展示如何用DTO处理GET与POST请求:GET响应只含安全字段,POST创建新特工时隐藏密码。借助工具如APIPost,可更高效管理API设计,实现安全、结构化的数据交互。总结来说,DTO让API更简洁、安全且高效。
|
9月前
|
JSON 搜索推荐 API
深入研究:京东商品视频 API 详解
京东商品视频API简介:该API可基于京东商品ID获取商品视频信息,包括标题、描述、播放地址、缩略图及视频时长等,助力开发者和商家实现个性化展示与智能推荐。接口采用HTTP GET方式请求,返回JSON格式数据。示例代码展示了通过Python的requests库调用API并生成签名的过程,确保请求安全可靠。此API有助于提升电商应用的用户体验与竞争力。
|
9月前
|
JSON API 开发者
京东API最新指南:商品视频接口接入与应用
在电商领域,商品视频能有效提升销售业绩。京东商品视频接口助力开发者获取商品视频信息(播放链接、时长、格式、封面图等),通过 HTTP GET/POST 请求返回 JSON 数据,便于集成到各类应用中,优化展示效果与用户体验。本指南详解接口接入与使用方法。