可以使用 detach()
方法来将参数的梯度值赋值给一个新的变量,并且确保这个变量的值不会随着梯度清零而变为零。示例如下:
import torch # 定义网络和输入 net = YourNetwork() input = torch.randn(batch_size, input_size) # 前向传播和计算损失 output = net(input) loss = calculate_loss(output) # 计算梯度并将其分配给一个新的变量 grads = torch.autograd.grad(loss, net.parameters(), retain_graph=True) grad_vars = [v.detach() for v in grads] # 清空梯度 net.zero_grad() # 在后续使用 grad_vars 计算梯度时,grad_vars 的值不会被修改
在这里,我们首先使用 torch.autograd.grad()
方法计算出每个参数的梯度,并将这些梯度赋值给一个列表 grads
中。然后,通过循环遍历这个列表,并将每个梯度张量通过 detach()
方法创建出一个新的张量 grad_vars
,这个新的张量与原来的梯度张量拥有相同的数值,但是它不再与计算图中的其他节点相连。因此,在后续使用这个新的张量计算梯度时,它的值不会受到原来的梯度张量被清零的影响。