【PyTorch基础教程18】squeeze和unsqueeze

简介: 栗子:torch.nn只支持小批量处理 (mini-batches)。整个 torch.nn 包只支持小批量样本的输入,不支持单个样本的输入。比如,nn.Conv2d 接受一个4维的张量,即nSamples x nChannels x Height x Width,如果是一个单独的样本,只需要使用input.unsqueeze(0) 来添加一个“假的”批大小维度。

学习总结

一、应用场景

栗子:torch.nn只支持小批量处理 (mini-batches)。整个 torch.nn 包只支持小批量样本的输入,不支持单个样本的输入。比如,nn.Conv2d 接受一个4维的张量,即nSamples x nChannels x Height x Width,如果是一个单独的样本,只需要使用input.unsqueeze(0) 来添加一个“假的”批大小维度。


PS:pytorch中,处理图片必须一个batch一个batch的操作,所以我们要准备的数据的格式是 [batch_size, n_channels, hight, width]。


二、升维和降维

降维:squeeze(input, dim = None, out = None)函数

(1)在不指定dim时,张量中形状为1的所有维都会除去。如input为(A, 1, B, 1, C, 1, D),output为(A, B, C, D)。


(2)如果要指定dim,降维操作只能在给定的维度上,如input为(A, 1, B)时:

错误用法:squeeze(input, dim = 0)会发现shape没变化,如下:

d = torch.randn(4, 1, 3)
print("d:", d)
# 没有变化
d1 = torch.squeeze(d, dim = 0) # 还是[4, 1, 3]
print("d1和d1的shape:", d1, d1.shape)
# dim=1处维除去
d2 = torch.squeeze(d, dim = 1) # 变成torch.Size([4, 3])
print("d2和d2的shape:", d2, d2.shape)

结果为:

d: tensor([[[ 1.8679, -0.9913, -2.6257]],
        [[-0.1690, -0.9938,  1.1178]],
        [[-1.2449,  2.5249,  2.2579]],
        [[ 0.2890, -0.5222, -0.2853]]])
d1和d1的shape: tensor([[[ 1.8679, -0.9913, -2.6257]],
        [[-0.1690, -0.9938,  1.1178]],
        [[-1.2449,  2.5249,  2.2579]],
        [[ 0.2890, -0.5222, -0.2853]]]) torch.Size([4, 1, 3])
d2和d2的shape: tensor([[ 1.8679, -0.9913, -2.6257],
        [-0.1690, -0.9938,  1.1178],
        [-1.2449,  2.5249,  2.2579],
        [ 0.2890, -0.5222, -0.2853]]) torch.Size([4, 3])

torch.unsqueeze有两种写法:

# -*- coding: utf-8 -*-
"""
Created on Fri Nov 26 20:08:58 2021
@author: 86493
"""
import torch
a = torch.tensor([1, 2, 3, 4])
print("a的shape:", a.shape)
# 将a的第一维升高 
b = torch.unsqueeze(a, dim = 0)
# b = a.unsqueeze(dim = 0) # 和上面的写法等价 
print("b和b的shape:", b, b.shape)
# 对b降维,去掉所有形状中为1的维 
c = b.squeeze()
print("c和c的shape:", c, c.shape)

结果为如下,即对第一维度升高后,b从a =【4】变为b=【【1, 2, 3,4】】:

a的shape: torch.Size([4])
b和b的shape: tensor([[1, 2, 3, 4]]) torch.Size([1, 4])
c和c的shape: tensor([1, 2, 3, 4]) torch.Size([4])
相关文章
|
1月前
|
存储 物联网 PyTorch
基于PyTorch的大语言模型微调指南:Torchtune完整教程与代码示例
**Torchtune**是由PyTorch团队开发的一个专门用于LLM微调的库。它旨在简化LLM的微调流程,提供了一系列高级API和预置的最佳实践
179 59
基于PyTorch的大语言模型微调指南:Torchtune完整教程与代码示例
|
1月前
|
并行计算 监控 搜索推荐
使用 PyTorch-BigGraph 构建和部署大规模图嵌入的完整教程
当处理大规模图数据时,复杂性难以避免。PyTorch-BigGraph (PBG) 是一款专为此设计的工具,能够高效处理数十亿节点和边的图数据。PBG通过多GPU或节点无缝扩展,利用高效的分区技术,生成准确的嵌入表示,适用于社交网络、推荐系统和知识图谱等领域。本文详细介绍PBG的设置、训练和优化方法,涵盖环境配置、数据准备、模型训练、性能优化和实际应用案例,帮助读者高效处理大规模图数据。
54 5
|
4月前
|
并行计算 Ubuntu PyTorch
Ubuntu下CUDA、Conda、Pytorch联合教程
本文是一份Ubuntu系统下安装和配置CUDA、Conda和Pytorch的教程,涵盖了查看显卡驱动、下载安装CUDA、添加环境变量、卸载CUDA、Anaconda的下载安装、环境管理以及Pytorch的安装和验证等步骤。
831 1
Ubuntu下CUDA、Conda、Pytorch联合教程
|
7月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.2 中文官方教程(十八)(1)
PyTorch 2.2 中文官方教程(十八)
236 2
PyTorch 2.2 中文官方教程(十八)(1)
|
7月前
|
并行计算 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十七)(4)
PyTorch 2.2 中文官方教程(十七)
235 2
PyTorch 2.2 中文官方教程(十七)(4)
|
7月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.2 中文官方教程(十九)(1)
PyTorch 2.2 中文官方教程(十九)
146 1
PyTorch 2.2 中文官方教程(十九)(1)
|
7月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十八)(3)
PyTorch 2.2 中文官方教程(十八)
102 1
PyTorch 2.2 中文官方教程(十八)(3)
|
7月前
|
API PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十八)(2)
PyTorch 2.2 中文官方教程(十八)
201 1
PyTorch 2.2 中文官方教程(十八)(2)
|
7月前
|
异构计算 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十七)(3)
PyTorch 2.2 中文官方教程(十七)
103 1
PyTorch 2.2 中文官方教程(十七)(3)
|
7月前
|
PyTorch 算法框架/工具 机器学习/深度学习
PyTorch 2.2 中文官方教程(十七)(2)
PyTorch 2.2 中文官方教程(十七)
161 1
PyTorch 2.2 中文官方教程(十七)(2)