(读Yolo3源码发现的不会的函数)Pytorch常用函数记录-pretrained-torch.nn.Upsample()函数-torch.cat-a.permute-a.view()等

简介: (读Yolo3源码发现的不会的函数)Pytorch常用函数记录-pretrained-torch.nn.Upsample()函数-torch.cat-a.permute-a.view()等

1、pretrained = False

我们经常会在pytorch的代码中看到这个参数,可以设置为True,也可以设置为False.

事实上这个参数常见于迁移学习的代码中,如果设置为True,则是启动下载预训练模型。

如果设置为False,则是不下载预训练模型,我们一般喜欢提前手动下载好,放置到相应的路径,因为一般设置为False。

2、torch.nn.Upsample()函数

实现上采样

import torch.nn as nn
nn.Upsample(size=None, scale_factor=None, mode='nearest', align_corners=None)

size:根据输入的不同,而确定输出的不同

scale_factor:确定输出为输入的多少倍

mode:可设置不同的上采用算法,nearest,linear,bilinear,bicubic , trilinear五种。默认使用nearest;

align_corners: 当设置为True,输入的角像素将与输出张量对齐,且保存下来这些像素的值,使用的算法为'linear', 'bilinear'or 'trilinear'时使用。默认设置为False。

3、torch.cat张量拼接函数

事实上对于做深度学习的人来讲,torch.cat用得是非常频繁的。特别实在目标检测算法中,我们经常需要进行不同维度的特征拼接,特征融合,这些大都是需要用到torch.cat。好的 废话少说,我们开始学torch.cat。


torch.cat是将两个张量拼接在一起,torch.cat((a, b ) , dim),这里a , b即为需要拼接的两个张量,dim为拼接的维度,当dim=0的时候为按 行 方向进行拼接,当dim = 1的时候为按照列的方向进行拼接。


4、a.permute维度替换函数

顾名思义就是将不同维度的数进行维度更换,因为在pytorch的运算处理当中,经常会出现多维度张量,而且不同维度的经常需要更换,因此我们就用到这个函数。

import torch
a = torch.randn(32 , 3 , 85 , 13 , 13)
b = a.permute(0 , 1 , 3 , 4 , 2)
print(b.shape)

输出结果:

(32 , 3 , 13 , 13 , 85)

可以看出来将2 , 3 , 4 维度的数进行更换,即维度更换。

5、a.view()张量重新resize函数

为了方便我们使用,我们用最简单的方法讲解函数最实用的用处,不深究原理。


pytorch里面的view()可以理解成numpy里面的resize(),都是对数据进行重新调整改变他的shape。


例如:原来数据a的shape是(6*8),a.view(24*2),则数据就会从6行8列,变成24行2列。


6、a.contiguous()函数

对于pytorch中的contiguous()函数,不同的人有不同的理解,在我的理解中我把他理解成一个深拷贝函数。


a.contiguous()方法常与a.permute()、a.transpose()、a.view()方法同时使用,对于这三个方法来说他们不会改变a在底层的存储方式,只是将输出形式以我们想看见的方式输出了(即只是改变了张量的输出形状)并没有开辟新内存,存储这个数据,如果对于想创建一个完全跟上诉脱离的数,则需要加.contiguous(),将数据深度拷贝,即copy了一个新数据跟之前的没有关联。


7、a.index_select()函数

这个可以理解成张量切片,本来这个想自己写的但是这位大佬写的太好了,直接看这位大佬写的把。


【python函数】torch.index_select()函数用法解析_风巽·剑染春水的博客-CSDN博客


8、pytorch.data属性和.detach属性

这两个的作用都是从正在进行的梯度张量运算中获取他的tensor值,他俩有相同的地方也有不同的地方。.data 和.detach都只取出本体tensor数据,舍弃了grad,grad_fn等额外反向图计算过程需保存的额外信息。就是将一个啰嗦的量,只获取他本身的值。


相同地方:-两者都是与原来的数据共享一个数据

-都是require s_grad = False

不同地方:-在进行反向传播的过程中如果会修改原数据,而.data会直接根据修改后的跟原数据不同的值进行错误的计算,而.detach会直报错,告诉你数据修改了。总之就是.detach会跟原来的数据还有关联,而且.data已经没有关联了,.detach如果发现原来的数据有变化了会告诉你,但是.data就会按照错误的走下去。


9、[..., i]

[..., i]相当于[:, :, … :, i],两者的效果是相等的,几个点就对应几个冒号。


例如[..., i]对应[:, :, :, i],[...., i]对应[:, :, : , : , i]。


可以将多维数据理解成切片,就是一层一层的片堆叠在一起,这样就方便理解我们这个的作用,就是为了提取相应的那一片数据,这一片可能从第一个维度的提取,也可以第二个维度等。

import torch
a = torch.rand((2, 8, 42, 4))  # 有4片数据
b = a[..., 0]  # 取第1片上的所有数据 ,数据size=[2,8,42]
c = a[:, :, :, 0]  # 取第1片上的所有数据 ,数据size=[2,8,42]
print(b.size())#类似数组切片
print(c.size())#一堆维度的贴片粘在一起,依次取第几维的
d = torch.rand((8, 42, 4))
b1 = d[..., 0]  # 取第1片上的所有数据 ,数据size=[8,42]
c1 = d[:, :, 0]  # 取第1片上的所有数据 ,数据size=[8,12]
c2 = d[:, 40]  # 取42个片上的第一行第一列的数据 数据size=[8,4]
# print(b1, c2)
print(b1.size(),c1.size(), c2.size())
输入结果:
torch.Size([2, 8, 42])
torch.Size([2, 8, 42])
torch.Size([8, 42]) torch.Size([8, 42]) torch.Size([8, 4])
相关文章
|
5月前
|
机器学习/深度学习 数据可视化 PyTorch
PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)
PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)
57 1
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch基础之网络模块torch.nn中函数和模板类的使用详解(附源码)
PyTorch基础之网络模块torch.nn中函数和模板类的使用详解(附源码)
74 0
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch基础之激活函数模块中Sigmoid、Tanh、ReLU、LeakyReLU函数讲解(附源码)
PyTorch基础之激活函数模块中Sigmoid、Tanh、ReLU、LeakyReLU函数讲解(附源码)
68 0
|
5月前
|
数据采集 PyTorch 算法框架/工具
PyTorch基础之数据模块Dataset、DataLoader用法详解(附源码)
PyTorch基础之数据模块Dataset、DataLoader用法详解(附源码)
419 0
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch基础之张量模块数据类型、基本操作、与Numpy数组的操作详解(附源码 简单全面)
PyTorch基础之张量模块数据类型、基本操作、与Numpy数组的操作详解(附源码 简单全面)
36 0
|
5月前
|
机器学习/深度学习 人工智能 算法
【PyTorch深度强化学习】TD3算法(双延迟-确定策略梯度算法)的讲解及实战(超详细 附源码)
【PyTorch深度强化学习】TD3算法(双延迟-确定策略梯度算法)的讲解及实战(超详细 附源码)
451 1
|
2月前
|
机器学习/深度学习 算法 大数据
基于PyTorch对凸函数采用SGD算法优化实例(附源码)
基于PyTorch对凸函数采用SGD算法优化实例(附源码)
33 3
|
2月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch的机器学习Regression问题实例(附源码)
基于Pytorch的机器学习Regression问题实例(附源码)
33 1
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch, 16个超强转换函数总结 ! !
PyTorch, 16个超强转换函数总结 ! !
35 1
|
5月前
|
机器学习/深度学习 传感器 算法
PyTorch基础之优化器模块、训练和测试模块讲解(附源码)
PyTorch基础之优化器模块、训练和测试模块讲解(附源码)
73 0