# 【从零开始学习深度学习】38. Pytorch实战案例：梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】

## 2. 读取训练数据

%matplotlib inline
import numpy as np
import time
import torch
from torch import nn, optim
import sys
import d2lzh_pytorch as d2l
def get_data_ch7():
data = np.genfromtxt('./data/airfoil_self_noise.dat', delimiter='\t')
# 标准化数据
data = (data - data.mean(axis=0)) / data.std(axis=0)
torch.tensor(data[:1500, -1], dtype=torch.float32) # 前1500个样本(每个样本包含5个特征)
features, labels = get_data_ch7()
features.shape # torch.Size([1500, 5])

## 3. 从零实现3种梯度算法并进行训练

# 参数优化器
def sgd(params, states, hyperparams):
for p in params:
# 训练函数
def train_ch7(optimizer_fn, states, hyperparams, features, labels,
batch_size=10, num_epochs=2):
# 初始化模型，初始化一个线性回归模型
net, loss = d2l.linreg, d2l.squared_loss

w = torch.nn.Parameter(torch.tensor(np.random.normal(0, 0.01, size=(features.shape[1], 1)), dtype=torch.float32),
def eval_loss():
return loss(net(features, w, b), labels).mean().item()
ls = [eval_loss()]
torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)

for _ in range(num_epochs):
start = time.time()
for batch_i, (X, y) in enumerate(data_iter):
l = loss(net(X, w, b), y).mean()  # 使用平均损失

# 梯度清零

l.backward()
optimizer_fn([w, b], states, hyperparams)  # 迭代模型参数
if (batch_i + 1) * batch_size % 100 == 0:
ls.append(eval_loss())  # 每100个样本记录下当前训练误差
# 打印结果和作图
print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))
d2l.set_figsize()
d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)
d2l.plt.xlabel('epoch')
d2l.plt.ylabel('loss')

### 3.1 梯度下降训练结果

def train_sgd(lr, batch_size, num_epochs=2):
train_ch7(sgd, None, {'lr': lr}, features, labels, batch_size, num_epochs)
train_sgd(1, 1500, 6)

loss: 0.245426, 0.013536 sec per epoch

### 3.2 随机梯度下降将结果

train_sgd(0.005, 1)

loss: 0.246051, 0.531435 sec per epoch

### 3.3 小批量随机梯度下降结果

train_sgd(0.05, 10)

loss: 0.242805, 0.078792 sec per epoch

## 4 .使用Pytorch的optim.SGD实现梯度下降优化算法

def train_pytorch_ch7(optimizer_fn, optimizer_hyperparams, features, labels,
batch_size=10, num_epochs=2):
# 初始化模型
net = nn.Sequential(
nn.Linear(features.shape[-1], 1)
)
loss = nn.MSELoss()
optimizer = optimizer_fn(net.parameters(), **optimizer_hyperparams)
def eval_loss():
return loss(net(features).view(-1), labels).item() / 2
ls = [eval_loss()]
torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)
for _ in range(num_epochs):
start = time.time()
for batch_i, (X, y) in enumerate(data_iter):
# 除以2是为了和train_ch7保持一致, 因为squared_loss中除了2
l = loss(net(X).view(-1), y) / 2

l.backward()
optimizer.step()
if (batch_i + 1) * batch_size % 100 == 0:
ls.append(eval_loss())
# 打印结果和作图
print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))
d2l.set_figsize()
d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)
d2l.plt.xlabel('epoch')
d2l.plt.ylabel('loss')

### 4.1 梯度下降训练结果

train_pytorch_ch7(optim.SGD, {"lr": 0.05}, features, labels, batch_size=1500, num_epochs=6)

loss: 0.701703, 0.013035 sec per epoch

### 4.2 随机梯度下降将结果

train_pytorch_ch7(optim.SGD, {"lr": 0.05}, features, labels, batch_size=1, num_epochs=2)

loss: 0.288860, 0.586868 sec per epoch

### 4.3 小批量随机梯度下降结果

train_pytorch_ch7(optim.SGD, {"lr": 0.05}, features, labels, batch_size=10, num_epochs=2)

loss: 0.242063, 0.075203 sec per epoch

## 5. 总结

• 小批量随机梯度每次随机均匀采样一个小批量的训练样本来计算梯度。
• 通常，小批量随机梯度在每个迭代周期的耗时介于梯度下降和随机梯度下降的耗时之间。

|
7天前
|

【7月更文挑战第8天】掌握Python算法三剑客：分治、贪心、动态规划。分治如归并排序，将大问题拆解递归解决；贪心策略在每步选最优解，如高效找零；动态规划利用子问题解，避免重复计算，解决最长公共子序列问题。实例展示，助你轻松驾驭算法！**
17 3
|
20天前
|

PyTorch框架和MNIST数据集
6月更文挑战20天
|
4天前
|

【7月更文挑战第11天】快速排序是编程基础，以O(n log n)时间复杂度和原址排序著称。其核心是“分而治之”，通过选择基准元素分割数组并递归排序两部分。优化包括：选择中位数作基准、尾递归优化、小数组用简单排序。以下是一个考虑优化的Python实现片段，展示了随机基准选择。通过实践和优化，能提升算法技能。**
8 3
|
27天前
|

Inception v3算法的实战与解析
Inception v3算法的实战与解析
30 3
|
26天前
|

60 1
|
5天前
|

Java面试题：Java内存探秘与多线程并发实战，Java内存模型及分区：理解Java堆、栈、方法区等内存区域的作用，垃圾收集机制：掌握常见的垃圾收集算法及其优缺点
Java面试题：Java内存探秘与多线程并发实战，Java内存模型及分区：理解Java堆、栈、方法区等内存区域的作用，垃圾收集机制：掌握常见的垃圾收集算法及其优缺点
8 0
|
9天前
|

14 0
|
1月前
|

97 0
|
14天前
|

【机器学习】CART决策树算法的核心思想及其大数据时代银行贷款参考案例——机器认知外界的重要算法
【机器学习】CART决策树算法的核心思想及其大数据时代银行贷款参考案例——机器认知外界的重要算法
22 0
|
21天前
|

Inception v3算法的实战与解析
Inception v3算法的实战与解析
15 0