[Deep Learning 02] Multivariate Linear Regression

Summary: To save training time, we can randomly sample b examples at each update (minibatch gradient descent).

⭐ In this post: the mathematical derivation of multivariate linear regression, gradient descent, and a PyTorch-based implementation


💌 Reference links:


ML1 Univariate Linear Regression — CSDN blog

dl3.2 Linear Regression | Kaggle

3.3. Concise Implementation of Linear Regression — Dive into Deep Learning 2.0.0-beta0 documentation (d2l.ai)


Basic Principles


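As a sketch of the standard setup (the formulas in the original figures are not recoverable, so the notation below follows d2l.ai and is an assumption):

Given $d$ input features $x_1, \ldots, x_d$, the model predicts

$$\hat{y} = w_1 x_1 + \cdots + w_d x_d + b = \mathbf{w}^\top \mathbf{x} + b$$

and training seeks the $\mathbf{w}, b$ that minimize the average squared loss over the $n$ training examples:

$$L(\mathbf{w}, b) = \frac{1}{n} \sum_{i=1}^{n} \frac{1}{2} \left( \mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)} \right)^2$$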


Gradient Descent


1. Steps


  • Pick an initial value $\mathbf{w}_0$


  • Repeatedly update the parameters: $t = 1, 2, 3, \ldots$


$$\mathbf{w}_t = \mathbf{w}_{t-1} - \eta \frac{\partial \ell}{\partial \mathbf{w}_{t-1}}$$


  • The loss increases along the gradient direction, so each update moves in the negative gradient direction




  • 🍔 Learning rate $\eta$: the hyperparameter that controls the step size (a small demo follows)


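The effect of the step size can be seen in a minimal one-dimensional sketch (my own illustration, not from the original post; the function f(w) = w² and the learning rates are assumed for demonstration):

# Gradient descent on f(w) = w^2, whose gradient is f'(w) = 2w.
def gd(eta, w=2.0, steps=10):
    for _ in range(steps):
        w = w - eta * 2 * w  # w_t = w_{t-1} - eta * f'(w_{t-1})
    return w

print(gd(0.05))  # ~0.70: too small a step, still far from the minimum at 0
print(gd(0.3))   # ~0.0002: a reasonable step, close to the minimum
print(gd(1.1))   # ~12.4: too large a step, the iterates diverge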


2. Minibatch Gradient Descent


Minibatch stochastic gradient descent is the default optimization method in deep learning.


Computing the gradient over the entire training set is expensive, so to save training time we can randomly sample b examples


$i_1, i_2, i_3, \ldots, i_b$


and use them to approximate the loss:


$$\frac{1}{b} \sum_{i \in \mathcal{B}} \ell^{(i)}(\mathbf{w}, b), \qquad \mathcal{B} = \{i_1, i_2, \ldots, i_b\}$$


  • 🍟 b is the number of training examples drawn per batch; it should be neither too large nor too small. A from-scratch sketch of one such update follows.
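As referenced above, here is a from-scratch sketch of one minibatch SGD update for linear regression (my own illustration; the PyTorch implementation below handles all of this automatically):

import torch

def minibatch_sgd_step(w, b, X, y, batch_size, lr):
    # sample b random indices i_1, ..., i_b
    idx = torch.randperm(X.shape[0])[:batch_size]
    Xb, yb = X[idx], y[idx]
    y_hat = Xb @ w + b               # predictions on the minibatch
    err = (y_hat - yb) / batch_size  # gradient of (1/2b)*sum((y_hat-y)^2) w.r.t. y_hat
    return w - lr * (Xb.T @ err), b - lr * err.sum()

# usage: recover w = [2, -3.4], b = 4.2 from synthetic data
X = torch.randn(1000, 2)
y = X @ torch.tensor([2.0, -3.4]) + 4.2
w, b = torch.zeros(2), torch.tensor(0.0)
for _ in range(300):
    w, b = minibatch_sgd_step(w, b, X, y, batch_size=10, lr=0.03)
print(w, b)  # approaches [2, -3.4] and 4.2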


Implementing Linear Regression


import torch
import random
import numpy as np
import matplotlib_inline.backend_inline
from matplotlib import pyplot as plt


1. Generating the Dataset


Define the true weights and bias:


num_inputs = 2       # number of input features per example
num_examples = 1000  # number of examples
true_w = [2, -3.4]   # true weights
true_b = 4.2         # true bias


Randomly generate the inputs `features` from a standard normal distribution; each example has two features:


features = torch.tensor(np.random.normal(0, 1, (num_examples,num_inputs)), dtype=torch.float)


Generate the outputs `labels` from `features` using the true model:


labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
labels = labels + torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)  # add Gaussian noise to the ground-truth values


2. Visualizing the Data


def use_svg_display():
    # display plots as vector graphics (SVG)
    matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

def set_figsize(figsize=(3.5, 2.5)):
    use_svg_display()
    # set the figure size
    plt.rcParams['figure.figsize'] = figsize

set_figsize()
plt.scatter(features[:, 1].numpy(), labels.numpy(), 1);


(figure: scatter plot of features[:, 1] against labels)


3. Reading the Data


PyTorch provides the `torch.utils.data` package for loading data.


With `batch_size = 10`, we randomly read minibatches of 10 examples:


import torch.utils.data as Data
batch_size = 10


First, inspect the generated features and labels:


features, labels


(tensor([[ 2.1282, -0.8920],
         [-1.0093,  0.2924],
         [ 1.5258, -0.5783],
         ...,
         [-1.6918, -0.7286],
         [ 0.2524,  1.0183],
         [-0.7605, -1.3882]]),
 tensor([ 1.1486e+01,  1.1739e+00,  9.2210e+00,  ...,  3.2928e+00,
          1.2413e+00,  7.4122e+00]))


Combine the features and labels of the training data into a dataset:


dataset = Data.TensorDataset(features,labels)
dataset


<torch.utils.data.dataset.TensorDataset at 0x7f1cf8f69250>


Randomly read minibatches:


data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)
data_iter


<torch.utils.data.dataloader.DataLoader at 0x7f1cf8f69d90>
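For intuition, the shuffled batching that `DataLoader` performs can be approximated by hand (a sketch; the helper name `data_iter_manual` is my own, and it uses the `random` module imported earlier):

def data_iter_manual(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)  # read the examples in random order
    for i in range(0, num_examples, batch_size):
        j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)])
        yield features.index_select(0, j), labels.index_select(0, j)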


Print the first minibatch of samples:


for x, y in data_iter:
    print(x, y)
    break


tensor([[ 1.8036,  0.2046],
        [-0.1144,  1.2728],
        [ 0.4237, -0.0737],
        [ 0.9859, -0.9073],
        [-1.9821,  2.7871],
        [-0.7970, -1.2866],
        [ 0.1646, -0.6856],
        [ 0.7002,  1.0762],
        [-0.9576, -1.3074],
        [ 0.7070, -0.1870]]) tensor([ 7.1322, -0.3659,  5.2927,  9.2499, -9.2418,  6.9704,  6.8748,  1.9477,
         6.7563,  6.2395])


4. Defining the Model


Import the `torch.nn` module:


import torch.nn as nn


Use `nn.Sequential` to build the network conveniently. `Sequential` is an ordered container: layers are added to the computational graph in the order they are passed in.


net = nn.Sequential(
    nn.Linear(num_inputs, 1)
    # additional layers can be added here
    )
print(net)
print(net[0])


Sequential(
  (0): Linear(in_features=2, out_features=1, bias=True)
)
Linear(in_features=2, out_features=1, bias=True)
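For reference, the same network can be built in other equivalent ways (standard `nn.Sequential` usage; `net2` and `net3` are illustrative names):

from collections import OrderedDict

# add named layers one by one
net2 = nn.Sequential()
net2.add_module('linear', nn.Linear(num_inputs, 1))

# or pass an OrderedDict of named layers
net3 = nn.Sequential(OrderedDict([
    ('linear', nn.Linear(num_inputs, 1))
    ]))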


All learnable parameters of the model can be inspected via `net.parameters()`, which returns a generator:


for param in net.parameters():
    print(param)


Parameter containing:
tensor([[0.6623, 0.0845]], requires_grad=True)
Parameter containing:
tensor([0.1282], requires_grad=True)


5. Initializing the Parameters


from torch.nn import init


We use `init.normal_` to initialize each weight element by sampling from a normal distribution with mean 0 and standard deviation 0.01, and initialize the bias to zero.


init.normal_(net[0].weight, mean=0, std=0.01)
init.constant_(net[0].bias, val=0)  # the bias data can also be modified directly
net[0].bias.data.fill_(0)


tensor([0.])
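Equivalently, the parameters can be initialized without `init` by modifying their `.data` tensors in place:

net[0].weight.data.normal_(0, 0.01)  # same effect as init.normal_
net[0].bias.data.fill_(0)            # same effect as init.constant_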


6. Defining the Loss Function


Use mean squared error (MSE) as the model's loss function:


loss = nn.MSELoss()
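With its default `reduction='mean'`, `nn.MSELoss` computes

$$\ell(\hat{y}, y) = \frac{1}{n} \sum_{i=1}^{n} \left( \hat{y}^{(i)} - y^{(i)} \right)^2$$

Note there is no factor of $\frac{1}{2}$, unlike the hand-written squared loss in the derivation above.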


7. Defining the Optimization Algorithm


Specify minibatch stochastic gradient descent (SGD) with a learning rate of 0.03 as the optimization algorithm:


import torch.optim as optim
optimizer = optim.SGD(
    # parameters without an explicitly specified learning rate use this outer default
    net.parameters(), lr=0.03
    )
print(optimizer)


SGD (
Parameter Group 0
    dampening: 0
    lr: 0.03
    momentum: 0
    nesterov: False
    weight_decay: 0
)


Instead of building a new optimizer, the learning rate can be adjusted dynamically by modifying the existing optimizer's parameter groups:


for param_group in optimizer.param_groups:
    param_group['lr'] *= 0.1  # scale the learning rate to 0.1x its previous value
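PyTorch's built-in schedulers can do the same thing; a minimal sketch with `StepLR` (the `step_size` and `gamma` values here are illustrative):

from torch.optim import lr_scheduler

scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # lr *= 0.1 every 5 epochs
# call scheduler.step() once per epoch inside the training loop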


8. Training the Model


num_epochs = 10  # number of passes over the dataset
for epoch in range(1, num_epochs+1):
    for x, y in data_iter:
        output = net(x)
        l = loss(output, y.view(-1, 1))  # reshape y to match the output's shape
        optimizer.zero_grad()  # clear gradients; equivalent to net.zero_grad()
        l.backward()
        optimizer.step()
    print('epoch: %d,loss: %f' %(epoch, l.item()))  # loss of the last minibatch


epoch: 1,loss: 15.321363
epoch: 2,loss: 6.421220
epoch: 3,loss: 0.666509
epoch: 4,loss: 0.269027
epoch: 5,loss: 0.071433
epoch: 6,loss: 0.027302
epoch: 7,loss: 0.005877
epoch: 8,loss: 0.001547
epoch: 9,loss: 0.000483
epoch: 10,loss: 0.000219
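The loss printed each epoch is that of the last minibatch only. As a quick sanity check (a sketch), the loss over the full dataset can be evaluated after training:

with torch.no_grad():  # gradients are not needed for evaluation
    print(loss(net(features), labels.view(-1, 1)).item())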


Compare the learned parameters with the true values:


dense = net[0]
print(true_w, dense.weight)
print(true_b, dense.bias)


[2, -3.4] Parameter containing:
tensor([[ 1.9972, -3.3939]], requires_grad=True)
4.2 Parameter containing:
tensor([4.1887], requires_grad=True)


