可以通过调用 backward()
方法,只计算网络所有参数梯度而不更新网络,然后使用 detach()
方法来断开计算图与参数之间的连接,以避免在后续的前向传播中对参数进行更新。具体实现如下:
import torch # 定义网络和输入 net = YourNetwork() input = torch.randn(batch_size, input_size) # 计算梯度,但不更新网络 output = net(input) loss = calculate_loss(output) loss.backward(retain_graph=True) # 断开计算图与参数之间的连接 for param in net.parameters(): param.grad.detach_() param.grad.zero_()
以上代码中,retain_graph=True
保留了计算图,以便后续再次计算梯度,而 zero_()
方法将梯度置零以便下一次计算梯度时重新计算。