Pytorch 常见运算(mul、mm、dot、mv)

简介: Pytorch 常见运算(mul、mm、dot、mv)

1.矩阵与标量


矩阵(张量)每一个元素与标量进行操作。


import torch
a = torch.tensor([1,2])
print(a+1)
>>> tensor([2, 3])


2.哈达玛积(mul


两个相同尺寸的张量相乘,然后对应元素的相乘就是这个哈达玛积。


a = torch.tensor([1,2])
b = torch.tensor([2,3])
print(a*b)
print(torch.mul(a,b))
>>> tensor([2, 6])
>>> tensor([2, 6])


这个torch.mul()和*以及torch.dot()是等价的


当然,除法也是类似的:


a = torch.tensor([1.,2.])
b = torch.tensor([2.,3.])
print(a/b)
print(torch.div(a/b))
>>> tensor([0.5000, 0.6667])
>>> tensor([0.5000, 0.6667])


我们可以发现的torch.div()其实就是/, 类似的:torch.add就是+,torch.sub()就是-,不过符号的运算更简单常用。


3.矩阵乘法


在代码中矩阵相乘有三种写法:


  • torch.mm()
  • torch.matmul()
  • @


a = torch.tensor([1.,2.])
b = torch.tensor([2.,3.]).view(1,2)
print(torch.mm(a, b))
print(torch.matmul(a, b))
print(a @ b)


输出结果:


tensor([[2., 3.],
        [4., 6.]])
tensor([[2., 3.],
        [4., 6.]])
tensor([[2., 3.],
        [4., 6.]])


上面的是对二维矩阵而言的,假如参与运算的是一个多维张量,那么只有torch.matmul()可以使用

torch.mv()等价于torch.mm(),不过不同的是mv适用与矩阵和向量相乘

在多维张量中,参与矩阵运算的其实只有后两个维度,前面的维度其实就像是索引一样,举个例子:


a = torch.rand((1,2,64,32))
b = torch.rand((1,2,32,64))
print(torch.matmul(a, b).shape)
>>> torch.Size([1, 2, 64, 64])


4.幂与开方


a = torch.tensor([1.,2.])
b = torch.tensor([2.,3.])
c1 = a ** b
c2 = torch.pow(a, b)
print(c1,c2)
>>> tensor([1., 8.]) tensor([1., 8.])


5.对数运算


pytorch中log是以e自然数为底数的,然后log2和log10才是以2和10为底数的运算。


import numpy as np
print('对数运算')
a = torch.tensor([2,10,np.e])
print(torch.log(a))
print(torch.log2(a))
print(torch.log10(a))
>>> tensor([0.6931, 2.3026, 1.0000])
>>> tensor([1.0000, 3.3219, 1.4427])
>>> tensor([0.3010, 1.0000, 0.4343]) 


6.近似值运算


  • .ceil() 向上取整
  • .floor()向下取整
  • .trunc()取整数
  • .frac()取小数
  • .round()四舍五入


a = torch.tensor(1.2345)
print(a.ceil())
>>>tensor(2.)
print(a.floor())
>>> tensor(1.)
print(a.trunc())
>>> tensor(1.)
print(a.frac())
>>> tensor(0.2345)
print(a.round())
>>> tensor(1.)


7.剪裁运算


这个是让一个数,限制在你自己设置的一个范围内[min,max],小于min的话就被设置为min,大于max的话就被设置为max。这个操作在一些对抗生成网络中,好像是WGAN-GP,通过强行限制模型的参数的值。


a = torch.rand(5)
print(a)
print(a.clamp(0.3,0.7))


输出为:


tensor([0.5271, 0.6924, 0.9919, 0.0095, 0.0340])
tensor([0.5271, 0.6924, 0.7000, 0.3000, 0.3000])

7ddb6c3069b240198e95af5acb4bb5a4.png


相关文章
|
4月前
|
TensorFlow 算法框架/工具
【Tensorflow】图解tf.image.extract_patches的用法--提取图片特定区域
文章通过图解和示例详细解释了TensorFlow中tf.image.extract_patches函数的用法,展示了如何使用该函数从图像中提取特定区域或分割图像为多个子图像。
77 0
|
5月前
|
数据挖掘 开发者 索引
【Python】已解决:ValueError: If using all scalar values, you must pass an index
【Python】已解决:ValueError: If using all scalar values, you must pass an index
1886 0
|
Docker 容器
求助: 运行模型时报错module 'megatron_util.mpu' has no attribute 'get_model_parallel_rank'
运行ZhipuAI/Multilingual-GLM-Summarization-zh的官方代码范例时,报错AttributeError: MGLMTextSummarizationPipeline: module 'megatron_util.mpu' has no attribute 'get_model_parallel_rank' 环境是基于ModelScope官方docker镜像,尝试了各个版本结果都是一样的。
415 5
|
Python
Python报错ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
Python报错ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
1647 1
|
7月前
GEE错误——Tile error: Arrays must have same lengths on all axes but the cat axis
GEE错误——Tile error: Arrays must have same lengths on all axes but the cat axis
69 1
|
7月前
|
Shell 计算机视觉 Python
no module named cv2 、numpy 、xxx超全解决方案
no module named cv2 、numpy 、xxx超全解决方案
torch.argmax(dim=1)用法
)torch.argmax(input, dim=None, keepdim=False)返回指定维度最大值的序号;
648 0
|
存储 Serverless
[oeasy]python0073_进制转化_eval_evaluate_衡量_oct_octal_八进制
[oeasy]python0073_进制转化_eval_evaluate_衡量_oct_octal_八进制
81 0
python SMAP_level2c nc 文件做线性拟合:y=ax+b
最近再处理卫星盐度数据时,通过时空匹配以及质量控制之后,需要对所得数据进行拟合分析。进而分析其误差分布、原因等。 根据学习,python中自带线性拟合的函数,使用起来较为方便快捷~
python SMAP_level2c nc 文件做线性拟合:y=ax+b
|
PyTorch 算法框架/工具
pytorch中meter.AverageValueMeter()使用方法
pytorch中meter.AverageValueMeter()使用方法
289 0