⭐ This post covers: the math behind multivariable linear regression, gradient descent, and an implementation in PyTorch
💌 Reference:
3.3. Concise Implementation of Linear Regression — Dive into Deep Learning 2.0.0-beta0 documentation (d2l.ai)
Basic Principles
Gradient Descent
1. Steps
- Pick an initial value $\mathbf{w}_0$
- Repeat for $t = 1, 2, 3, \ldots$: update the parameters
- Moving *along* the gradient direction increases the loss, so each update steps in the negative gradient direction (see the update rule below)
- 🍔 Learning rate $\eta$: a hyperparameter that controls the step size
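Written out, the per-iteration update rule (in the d2l.ai notation) is:

$$
\mathbf{w}_t = \mathbf{w}_{t-1} - \eta \, \frac{\partial \ell}{\partial \mathbf{w}_{t-1}}
$$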
2. Minibatch Gradient Descent
Minibatch stochastic gradient descent is the default optimization method in deep learning.
To save training time and computation, we can randomly sample $b$ examples $i_1, i_2, \ldots, i_b$ and use them to approximate the loss (see the formula below).
- 🍟 $b$ is the minibatch size, i.e. how many training examples are sampled; it should be neither too large nor too small
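With the sampled index set $\mathcal{I}_t = \{i_1, i_2, \ldots, i_b\}$, the minibatch approximation of the loss (again following d2l.ai) is:

$$
\frac{1}{b} \sum_{i \in \mathcal{I}_t} \ell\left(\mathbf{x}_i, y_i, \mathbf{w}\right)
$$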
Linear Regression Implementation
```python
import torch
import random
import numpy as np
import matplotlib_inline.backend_inline
from matplotlib import pyplot as plt
```
1. Generating the Dataset
Define the true weights and bias:
```python
num_inputs = 2       # number of input features per example
num_examples = 1000  # number of samples
true_w = [2, -3.4]   # true weights
true_b = 4.2         # true bias
```
Randomly generate normally distributed inputs `features`; each example consists of two input features:
```python
features = torch.tensor(np.random.normal(0, 1, (num_examples, num_inputs)), dtype=torch.float)
```
Generate the outputs `labels` from the inputs `features`:
```python
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
# add Gaussian noise on top of the true values
labels = labels + torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)
```
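In equation form, the synthetic data therefore follow the linear model below, with noise $\epsilon$ drawn i.i.d. from $\mathcal{N}(0,\, 0.01^2)$:

$$
y = 2\,x_1 - 3.4\,x_2 + 4.2 + \epsilon
$$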
2. Visualizing the Data
```python
def use_svg_display():
    # render figures as vector graphics (SVG)
    matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

def set_figsize(figsize=(3.5, 2.5)):
    use_svg_display()
    # set the figure size
    plt.rcParams['figure.figsize'] = figsize

set_figsize()
plt.scatter(features[:, 1].numpy(), labels.numpy(), 1);
```
3. Reading the Data
PyTorch provides the `data` package for reading data.
With `batch_size = 10`, we will randomly read minibatches of 10 samples each.
```python
import torch.utils.data as Data

batch_size = 10
```
Take a look at the generated features and labels:
```python
features, labels
```

```
(tensor([[ 2.1282, -0.8920],
         [-1.0093,  0.2924],
         [ 1.5258, -0.5783],
         ...,
         [-1.6918, -0.7286],
         [ 0.2524,  1.0183],
         [-0.7605, -1.3882]]),
 tensor([11.4860,  1.1739,  9.2210,  ...,  3.2928,  1.2413,  7.4122]))
```
Combine the features and labels of the training data:
```python
dataset = Data.TensorDataset(features, labels)
dataset
```

```
<torch.utils.data.dataset.TensorDataset at 0x7f1cf8f69250>
```
Randomly read minibatches:
```python
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)
data_iter
```

```
<torch.utils.data.dataloader.DataLoader at 0x7f1cf8f69d90>
```
Print the first minibatch of samples:
```python
for x, y in data_iter:
    print(x, y)
    break
```

```
tensor([[ 1.8036,  0.2046],
        [-0.1144,  1.2728],
        [ 0.4237, -0.0737],
        [ 0.9859, -0.9073],
        [-1.9821,  2.7871],
        [-0.7970, -1.2866],
        [ 0.1646, -0.6856],
        [ 0.7002,  1.0762],
        [-0.9576, -1.3074],
        [ 0.7070, -0.1870]]) tensor([ 7.1322, -0.3659,  5.2927,  9.2499, -9.2418,  6.9704,  6.8748,  1.9477,
         6.7563,  6.2395])
```
4. Defining the Model
Import the `torch.nn` module:
```python
import torch.nn as nn
```
We use `nn.Sequential` to build the network conveniently. `Sequential` is an ordered container: layers are added to the computation graph in the order in which they are passed in.
```python
net = nn.Sequential(
    nn.Linear(num_inputs, 1)
    # other layers could be added here
)
print(net)
print(net[0])
```

```
Sequential(
  (0): Linear(in_features=2, out_features=1, bias=True)
)
Linear(in_features=2, out_features=1, bias=True)
```
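For reference, the same network can be built in other ways that `nn.Sequential` supports, e.g. adding named layers via `add_module` or passing an `OrderedDict`; a minimal sketch (the layer name `'linear'` is an arbitrary choice here):

```python
from collections import OrderedDict

# variant 1: create an empty container, then add named layers one by one
net = nn.Sequential()
net.add_module('linear', nn.Linear(num_inputs, 1))

# variant 2: pass an OrderedDict of named layers
net = nn.Sequential(OrderedDict([
    ('linear', nn.Linear(num_inputs, 1))
]))
```

Either way, the container still supports integer indexing such as `net[0]`, as used below.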
The model's learnable parameters can be inspected with `net.parameters()`, which returns a generator:
```python
for param in net.parameters():
    print(param)
```

```
Parameter containing:
tensor([[0.6623, 0.0845]], requires_grad=True)
Parameter containing:
tensor([0.1282], requires_grad=True)
```
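To see which tensor is which, the standard `nn.Module` method `net.named_parameters()` also yields each parameter's name; a quick sketch:

```python
# with the positional Sequential above, the linear layer's parameters
# are named '0.weight' (shape (1, 2)) and '0.bias' (shape (1,))
for name, param in net.named_parameters():
    print(name, tuple(param.shape))
```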
5. Initializing the Parameters
```python
from torch.nn import init
```
We use `init.normal_` to initialize each weight element by sampling from a normal distribution with mean 0 and standard deviation 0.01, and we initialize the bias to zero.
```python
init.normal_(net[0].weight, mean=0, std=0.01)
init.constant_(net[0].bias, val=0)
# alternatively, modify the bias data directly
net[0].bias.data.fill_(0)
```

```
tensor([0.])
```
6. Defining the Loss Function
Use mean squared error (MSE) as the model's loss function:
```python
loss = nn.MSELoss()
```
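With its default `reduction='mean'`, `nn.MSELoss` averages the squared errors over all elements of the minibatch:

$$
\ell(\hat{y}, y) = \frac{1}{n} \sum_{i=1}^{n} \left(\hat{y}_i - y_i\right)^2
$$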
7. Defining the Optimization Algorithm
Specify minibatch stochastic gradient descent (SGD) with a learning rate of 0.03 as the optimization algorithm:
```python
import torch.optim as optim

# parameters without their own learning rate use the outer default lr
optimizer = optim.SGD(net.parameters(), lr=0.03)
print(optimizer)
```

```
SGD (
Parameter Group 0
    dampening: 0
    lr: 0.03
    momentum: 0
    nesterov: False
    weight_decay: 0
)
```
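The comment about default learning rates refers to `torch.optim`'s parameter groups: each group may carry its own `lr`, and groups without one inherit the outer default. A toy sketch using this post's single layer, splitting its weight and bias into separate groups purely to illustrate the mechanism (the split itself is arbitrary):

```python
# toy example: two parameter groups with different learning rates
optimizer2 = optim.SGD([
    {'params': [net[0].weight]},             # no 'lr' here: inherits the default 0.03
    {'params': [net[0].bias], 'lr': 0.003},  # this group overrides the default
], lr=0.03)
```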
To adjust the learning rate dynamically, instead of constructing a new optimizer we can modify `optimizer.param_groups` in place:
```python
for param_group in optimizer.param_groups:
    param_group['lr'] *= 0.1  # scale the learning rate to 0.1x its previous value
```
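As an aside (not used in these notes), `torch.optim.lr_scheduler` can apply the same kind of decay on a schedule; a minimal sketch that multiplies the learning rate by 0.1 every 5 epochs:

```python
from torch.optim.lr_scheduler import StepLR

scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
# call scheduler.step() once per epoch, after the epoch's optimizer.step() calls
```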
8. Training the Model
```python
num_epochs = 10  # number of passes over the data; more epochs generally lower the final training loss
for epoch in range(1, num_epochs + 1):
    for x, y in data_iter:
        output = net(x)
        l = loss(output, y.view(-1, 1))
        optimizer.zero_grad()  # clear gradients, equivalent to net.zero_grad()
        l.backward()
        optimizer.step()
    print('epoch: %d,loss: %f' % (epoch, l.item()))
```
```
epoch: 1,loss: 15.321363
epoch: 2,loss: 6.421220
epoch: 3,loss: 0.666509
epoch: 4,loss: 0.269027
epoch: 5,loss: 0.071433
epoch: 6,loss: 0.027302
epoch: 7,loss: 0.005877
epoch: 8,loss: 0.001547
epoch: 9,loss: 0.000483
epoch: 10,loss: 0.000219
```
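Note that each printed loss is only the loss on the last minibatch of that epoch. As a quick sanity check (not part of the original notebook), the trained model can also be evaluated on the full synthetic dataset:

```python
# average squared error over all 1000 examples, with gradients disabled
with torch.no_grad():
    print(loss(net(features), labels.view(-1, 1)).item())
```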
Compare the learned parameters with the true values:
```python
dense = net[0]
print(true_w, dense.weight)
print(true_b, dense.bias)
```

```
[2, -3.4] Parameter containing:
tensor([[ 1.9972, -3.3939]], requires_grad=True)
4.2 Parameter containing:
tensor([4.1887], requires_grad=True)
```
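The learned parameters are close to the true ones. As a final usage sketch (assumed, not from the original notes), prediction on a new input is a single forward pass; for $x = (1, 1)$ the true model gives $2 - 3.4 + 4.2 = 2.8$:

```python
# predict for a new example x = (1.0, 1.0); the output should be close to 2.8
with torch.no_grad():
    print(net(torch.tensor([[1.0, 1.0]])))
```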