torch.autograd.Function 学习理解

简介: 文章目录前言一、概述二、例程三、官方的demo(指数函数)

前言


在量化感知训练,为了能够进行反向传播,会引入直通估计器,用于保证参数可以求导。我们需要自己定义这些操作,且定义反向求导函数,由于基础知识薄弱,便仔细学习了相关知识。

一、概述

torch.autograd.Function

只需要实现两个 静态方法:


forward可以有任意多个输入、任意多个输出,但是输入和输出必须是Variable。

backward的输入和输出的个数就是forward()函数的输出和输入的个数。其中,backward()输入表示关于forward()输出的梯度,backward()的输出表示关于forward()的输入的梯度。

另外还要加上ctx,它可以理解为一个上下文管理器。


定义新的操作,意味着定义Function的子类,并且这些子类必须重写以下函数:forward()backward()。初始化函数:__init__()根据实际需求判断是否需要重写。

二、例程

from torch.autograd import Function
class MultiplyAdd(Function):
    @staticmethod
    def forward(ctx, w, x, b):
        print('type in forward', type(x))
        ctx.save_for_backward(w, x)#存储用来反向传播的参数
        output = w*x +b
        return output
    @staticmethod
    def backward(ctx, grad_output):
        w, x = ctx.saved_tensors #deprecated,现在使用saved_tensors
        print('type in backward',type(x))
        grad_w = grad_output * x
        grad_x = grad_output * w
        grad_b = grad_output * 1
        return grad_w, grad_x, grad_b
w = torch.rand(2, 2, requires_grad=True)
x = torch.rand(2, 2, requires_grad=True)
b = torch.rand(2, 2, requires_grad=True)
out = MultiplyAdd.apply(w, x, b)
out.backward(torch.ones(2,2))
w.grad,x.grad,b.grad
(tensor([[0.5159, 0.4950],
         [0.1050, 0.7115]]),
 tensor([[0.6249, 0.4731],
         [0.7905, 0.1637]]),
 tensor([[1., 1.],
         [1., 1.]]))

三、官方的demo(指数函数)

import torch
import torch.nn.functional as F
from torch.autograd import Function
class MyExp(Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result) #将result转移到Variable保存在ctx中
        return result
    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result
def exp(x):
    return MyExp.apply(x)
x = torch.Tensor([0,2]).requires_grad_(True)
out = MyExp.apply(x)
out.backward(torch.Tensor([0,2]))
x.grad
tensor([ 0.0000, 14.7781])
目录
打赏
0
0
0
0
1
分享
相关文章
Pytorch学习笔记(五):nn.AdaptiveAvgPool2d()函数详解
PyTorch中的`nn.AdaptiveAvgPool2d()`函数用于实现自适应平均池化,能够将输入特征图调整到指定的输出尺寸,而不需要手动计算池化核大小和步长。
295 1
Pytorch学习笔记(五):nn.AdaptiveAvgPool2d()函数详解
【Tensorflow+keras】解决使用model.load_weights时报错 ‘str‘ object has no attribute ‘decode‘
python 3.6,Tensorflow 2.0,在使用Tensorflow 的keras API,加载权重模型时,报错’str’ object has no attribute ‘decode’
82 0
Python报错ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
Python报错ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
1834 1
torch.argmax(dim=1)用法
)torch.argmax(input, dim=None, keepdim=False)返回指定维度最大值的序号;
657 0
Pytorch出现RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor)
这个问题的主要原因是输入的数据类型与网络参数的类型不符。
731 0
详细介绍torch中的from torch.utils.data.sampler相关知识
PyTorch中的torch.utils.data.sampler模块提供了一些用于数据采样的类和函数,这些类和函数可以用于控制如何从数据集中选择样本。下面是一些常用的Sampler类和函数的介绍: Sampler基类: Sampler是一个抽象类,它定义了一个__iter__方法,返回一个迭代器,用于生成数据集中的样本索引。 RandomSampler: 随机采样器,它会随机从数据集中选择样本。可以设置随机数种子,以确保每次采样结果相同。 SequentialSampler: 顺序采样器,它会按照数据集中的顺序,依次选择样本。 SubsetRandomSampler: 子集随机采样器
680 0
成功解决numpy.core._internal.AxisError: axis -1 is out of bounds for array of dimension 0
成功解决numpy.core._internal.AxisError: axis -1 is out of bounds for array of dimension 0
AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等