Machine Learning: PM2.5 Prediction in Plain Language

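Summary: This project is for reference only. It offers ideas and an approach, not a model answer. Please think twice before copying it!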



Homework 1: PM2.5 Prediction



Project Description


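  • The data for this assignment are observations downloaded from the air quality monitoring network of Taiwan's Environmental Protection Administration (Executive Yuan).
  • The goal is to implement linear regression that predicts PM2.5 values.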


Dataset


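  • This assignment uses the observation records of the Fengyuan station, split into a train set and a test set. The train set contains all data for the first 20 days of each month at Fengyuan; the test set is sampled from the station's remaining data.
  • train.csv: the complete data for the first 20 days of each month.
  • test.csv: from the remaining data, consecutive 10-hour windows are sampled as individual examples; all observations of the first nine hours are the features, and the PM2.5 of the tenth hour is the answer. There are 240 distinct test examples in total; predict PM2.5 for these 240 examples from their features.
  • The data contain 18 measurement items: AMB_TEMP, CH4, CO, NMHC, NO, NO2, NOx, O3, PM10, PM2.5, RAINFALL, RH, SO2, THC, WD_HR, WIND_DIREC, WIND_SPEED, WS_HR.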
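Later code picks the target with `month_data[month][9, ...]`, i.e. it relies on PM2.5 being row 9 of each day's 18-row block. A small sketch that makes this explicit, assuming the rows of each day appear in the order listed above (the fourth item is spelled NMHC, non-methane hydrocarbons, in the data). The names `FEATURES`, `PM25_ROW` and `RAINFALL_ROW` are illustrative, not part of the original code:

```python
# The 18 measurement items, assumed to appear in this order in every
# 18-row daily block of train.csv / test.csv.
FEATURES = [
    "AMB_TEMP", "CH4", "CO", "NMHC", "NO", "NO2", "NOx", "O3", "PM10",
    "PM2.5", "RAINFALL", "RH", "SO2", "THC", "WD_HR", "WIND_DIREC",
    "WIND_SPEED", "WS_HR",
]

PM25_ROW = FEATURES.index("PM2.5")         # row used as the regression target
RAINFALL_ROW = FEATURES.index("RAINFALL")  # row whose entries may be 'NR'

print(len(FEATURES), PM25_ROW, RAINFALL_ROW)  # 18 9 10
```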


Requirements


  • Implement linear regression by hand; the only allowed optimization method is gradient descent.
  • Using numpy.linalg.lstsq is not allowed.


Data Preparation



Environment Setup / Installation


Understanding the Data

After inspecting the data, its layout is as shown in the figure:


[Figure: layout of the training data]


Horizontally: the values for each of the 24 hours of a day.

Vertically: 12 months × 20 days per month × 18 measurement types per day.
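Because of this layout, the row holding a given measurement on a given day can be computed directly; the reshaping loop further down uses the same formula, `18 * (20 * month + day)`, to find the start of each day's block. A minimal sketch (the helper `raw_row` is hypothetical, with 0-based month, day and feature indices, and rows counted after the header):

```python
N_FEATURES = 18      # measurement items per day
DAYS_PER_MONTH = 20  # only the first 20 days of each month are in train.csv


def raw_row(month: int, day: int, feature: int) -> int:
    """0-based data row of train.csv that holds `feature` on `day` of `month`.

    Each day is a block of 18 consecutive rows, and each month is 20 such blocks.
    """
    return N_FEATURES * (DAYS_PER_MONTH * month + day) + feature


print(raw_row(0, 0, 9))     # 9    -> PM2.5 on the first day of the first month
print(raw_row(11, 19, 17))  # 4319 -> last row: WS_HR on day 19 of month 11
```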


!pip install --upgrade pandas


Looking in indexes: https://mirror.baidu.com/pypi/simple/
Requirement already up-to-date: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (1.2.3)
Requirement already satisfied, skipping upgrade: pytz>=2017.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pandas) (2019.3)
Requirement already satisfied, skipping upgrade: python-dateutil>=2.7.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pandas) (2.8.0)
Requirement already satisfied, skipping upgrade: numpy>=1.16.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pandas) (1.20.1)
Requirement already satisfied, skipping upgrade: six>=1.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from python-dateutil>=2.7.3->pandas) (1.15.0)


Import the required packages and read the training set:


import numpy as np
import pandas as pd
data = pd.read_csv('work/hw1_data/train.csv', encoding = 'big5')  # the file is Big5-encoded (Traditional Chinese)


print(data)  # inspect the data
print(data.shape)  # inspect its size
         0     1     2     3     4     5     6     7     8     9  ...    14  \
0       14    14    14    13    12    12    12    12    15    17  ...    22   
1      1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8  ...   1.8   
2     0.51  0.41  0.39  0.37  0.35   0.3  0.37  0.47  0.78  0.74  ...  0.37   
3      0.2  0.15  0.13  0.12  0.11  0.06   0.1  0.13  0.26  0.23  ...   0.1   
4      0.9   0.6   0.5   1.7   1.8   1.5   1.9   2.2   6.6   7.9  ...   2.5   
...    ...   ...   ...   ...   ...   ...   ...   ...   ...   ...  ...   ...   
4315   1.8   1.8   1.8   1.8   1.8   1.7   1.7   1.8   1.8   1.8  ...   1.8   
4316    46    13    61    44    55    68    66    70    66    85  ...    59   
4317    36    55    72   327    74    52    59    83   106   105  ...    18   
4318   1.9   2.4   1.9   2.8   2.3   1.9   2.1   3.7   2.8   3.8  ...   2.3   
4319   0.7   0.8   1.8     1   1.9   1.7   2.1     2     2   1.7  ...   1.3   
        15    16    17    18    19    20    21    22    23  
0       22    21    19    17    16    15    15    15    15  
1      1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8  
2     0.37  0.47  0.69  0.56  0.45  0.38  0.35  0.36  0.32  
3     0.13  0.14  0.23  0.18  0.12   0.1  0.09   0.1  0.08  
4      2.2   2.5   2.3   2.1   1.9   1.5   1.6   1.8   1.5  
...    ...   ...   ...   ...   ...   ...   ...   ...   ...  
4315   1.8     2   2.1     2   1.9   1.9   1.9     2     2  
4316   308   327    21   100   109   108   114   108   109  
4317   311    52    54   121    97   107   118   100   105  
4318   2.6   1.3     1   1.5     1   1.7   1.5     2     2  
4319   1.7   0.7   0.4   1.1   1.4   1.3   1.6   1.8     2  
[4320 rows x 24 columns]
(4320, 24)


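Keep only the numeric part, and set the 'NR' entries (no rainfall recorded) of the 'RAINFALL' rows to 0.

Inspecting the data shows that some values are missing or non-numeric, so they have to be filled in before the data can be used.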




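The loaded data also contain some descriptive columns (date, station, measurement item) that we do not need, so we slice them off and ignore them.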




data.iloc[:, :] selects rows and columns by position; here it is used to cut out the part of the data we need.

data[data == 'xxx'] = 0 replaces every 'xxx' entry with 0.
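To see the two operations in isolation, here is a sketch on a tiny toy frame shaped like train.csv. The column names and cell values are placeholders (the real header is in Chinese). Since the real hourly columns mix numbers and 'NR', pandas reads them as strings, which is why the sketch also casts to float at the end:

```python
import pandas as pd

# Toy frame: 3 descriptive columns, then hourly values stored as strings.
df = pd.DataFrame({
    "date": ["2014/1/1", "2014/1/1"],
    "station": ["Fengyuan", "Fengyuan"],
    "item": ["PM2.5", "RAINFALL"],
    "0": ["26", "NR"],
    "1": ["39", "0.2"],
})

values = df.iloc[:, 3:].copy()         # keep only the hourly columns; .copy() avoids SettingWithCopyWarning
values[values == "NR"] = 0             # every 'NR' entry becomes 0
arr = values.to_numpy().astype(float)  # object array of strings/ints -> float array
print(arr)
```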


data = data.iloc[:, 3:]  # keep the columns from index 3 on (drop the descriptive columns)
print(data)  # inspect the data
print(data.shape)  # inspect its size
print(type(data))
data[data == 'NR'] = 0  # replace 'NR' entries with 0
raw_data = data.to_numpy()  # convert to a numpy array
         0     1     2     3     4     5     6     7     8     9  ...    14  \
0       14    14    14    13    12    12    12    12    15    17  ...    22   
1      1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8  ...   1.8   
2     0.51  0.41  0.39  0.37  0.35   0.3  0.37  0.47  0.78  0.74  ...  0.37   
3      0.2  0.15  0.13  0.12  0.11  0.06   0.1  0.13  0.26  0.23  ...   0.1   
4      0.9   0.6   0.5   1.7   1.8   1.5   1.9   2.2   6.6   7.9  ...   2.5   
...    ...   ...   ...   ...   ...   ...   ...   ...   ...   ...  ...   ...   
4315   1.8   1.8   1.8   1.8   1.8   1.7   1.7   1.8   1.8   1.8  ...   1.8   
4316    46    13    61    44    55    68    66    70    66    85  ...    59   
4317    36    55    72   327    74    52    59    83   106   105  ...    18   
4318   1.9   2.4   1.9   2.8   2.3   1.9   2.1   3.7   2.8   3.8  ...   2.3   
4319   0.7   0.8   1.8     1   1.9   1.7   2.1     2     2   1.7  ...   1.3   
        15    16    17    18    19    20    21    22    23  
0       22    21    19    17    16    15    15    15    15  
1      1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8   1.8  
2     0.37  0.47  0.69  0.56  0.45  0.38  0.35  0.36  0.32  
3     0.13  0.14  0.23  0.18  0.12   0.1  0.09   0.1  0.08  
4      2.2   2.5   2.3   2.1   1.9   1.5   1.6   1.8   1.5  
...    ...   ...   ...   ...   ...   ...   ...   ...   ...  
4315   1.8     2   2.1     2   1.9   1.9   1.9     2     2  
4316   308   327    21   100   109   108   114   108   109  
4317   311    52    54   121    97   107   118   100   105  
4318   2.6   1.3     1   1.5     1   1.7   1.5     2     2  
4319   1.7   0.7   0.4   1.1   1.4   1.3   1.6   1.8     2  
[4320 rows x 24 columns]
(4320, 24)
<class 'pandas.core.frame.DataFrame'>
print(raw_data.shape)  # inspect the array size
print(type(raw_data))  # inspect the type
(4320, 24)
<class 'numpy.ndarray'>




Reshape the data from the original 24 × (18 × 20 × 12) into 12 × 18 × (20 × 24), i.e. one 18 × 480 matrix per month:

month_data = {}
for month in range(12):
    sample = np.empty([18, 480])  # new [18, 480] array, uninitialized (filled in below)
    for day in range(20):
        sample[:, day * 24 : (day + 1) * 24] = raw_data[18 * (20 * month + day) : 18 * (20 * month + day + 1), :]
    month_data[month] = sample


# print(len(month_data),len(month_data[0]))  # check the size
# print(month_data)  # inspect the data
print(month_data[month])  # after the loop, month == 11 (the last month)
print(month_data[month].shape)
[[ 23.    23.    23.   ...  13.    13.    13.  ]
 [  1.6    1.7    1.7  ...   1.8    1.8    1.8 ]
 [  0.22   0.2    0.18 ...   0.51   0.57   0.56]
 ...
 [ 93.    50.    99.   ... 118.   100.   105.  ]
 [  1.8    2.1    3.2  ...   1.5    2.     2.  ]
 [  1.3    0.9    1.   ...   1.6    1.8    2.  ]]
(18, 480)
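The same reshaping can be done without loops. Since raw row index = 18 × (20 × month + day) + feature, the rows split cleanly into (month, day, feature) axes, and swapping the day and feature axes gives the per-month layout. A sketch under that assumption; `to_month_data` is a hypothetical helper, and it casts to float first because `data.to_numpy()` on the cleaned frame still returns an object array of strings:

```python
import numpy as np


def to_month_data(raw_data: np.ndarray) -> np.ndarray:
    """Reshape (12*20*18, 24) raw rows into a (12, 18, 480) array of month matrices."""
    by_day = raw_data.astype(float).reshape(12, 20, 18, 24)  # [month, day, feature, hour]
    by_feature = by_day.transpose(0, 2, 1, 3)                # [month, feature, day, hour]
    return by_feature.reshape(12, 18, 20 * 24)               # [month, feature, day*24 + hour]


# Compare with the loop version on random data.
raw = np.random.default_rng(0).normal(size=(4320, 24))
fast = to_month_data(raw)
loop = {}
for month in range(12):
    sample = np.empty([18, 480])
    for day in range(20):
        sample[:, day * 24:(day + 1) * 24] = raw[18 * (20 * month + day):18 * (20 * month + day + 1), :]
    loop[month] = sample
print(all(np.array_equal(fast[m], loop[m]) for m in range(12)))  # True
```

`fast[m]` plays the role of `month_data[m]` in the rest of the notebook.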




Each month has 480 hours. Every 9 consecutive hours form one sample, so each month yields 471 samples and the whole set has 471 × 12 samples; each sample has 9 × 18 features (18 features per hour × 9 hours).


  • 471 samples/month × 12 months is the amount of data we get, and each sample is 9 hours × 18 measurement types.


The corresponding targets also number 471 × 12 (the PM2.5 value of the 10th hour).


  • The target is the PM2.5 value in the hour right after the 9-hour window.


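Explanation:

One month of data is 20 days × 24 hours = 480 hours.

Sliding a 9-hour window over them one hour at a time gives 480 − 9 + 1 = 472 windows.

But the last window has no following hour, so it has no y value.

Hence: 472 − 1 = 471.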


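  • Why windows of 9 hours and not 10?
    The task says: all observations of the first nine hours are the features, and the PM2.5 of the tenth hour is the answer.
    So 9 hours of data form the input, and the 10th hour supplies the value to compare against.
  • Why not take 10 hours at a time and then process them?
    That works too: a sliding 10-hour window (9 input hours + 1 target hour) gives 480 − 10 + 1 = 471 windows per month, the same count as above. Only non-overlapping blocks of 10 hours would reduce the training data (to 48 samples per month).
    Which way you think about it is up to you.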
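The sample count above is plain arithmetic; a quick check that it matches the (5652, 162) shape printed below:

```python
HOURS_PER_MONTH = 20 * 24  # 480 consecutive hours per month in train.csv
WINDOW = 9 + 1             # 9 feature hours + 1 target hour

samples_per_month = HOURS_PER_MONTH - WINDOW + 1  # sliding 10-hour windows
total_samples = 12 * samples_per_month
print(samples_per_month, total_samples)  # 471 5652
```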


x = np.empty([12 * 471, 18 * 9], dtype = float)  
y = np.empty([12 * 471, 1], dtype = float)
for month in range(12):
    for day in range(20):
        for hour in range(24):
            if day == 19 and hour > 14:  # windows starting in the last 9 hours would run past hour 479
                continue
            x[month * 471 + day * 24 + hour, :] = month_data[month][:,day * 24 + hour : day * 24 + hour + 9].reshape(1, -1)  # 18 features × 9 hours, flattened to 162 values
            y[month * 471 + day * 24 + hour, 0] = month_data[month][9, day * 24 + hour + 9]  # PM2.5 (row 9) in the 10th hour
print(x)
print(x.shape)
print(y)
print(y.shape)
[[14.  14.  14.  ...  2.   2.   0.5]
 [14.  14.  13.  ...  2.   0.5  0.3]
 [14.  13.  12.  ...  0.5  0.3  0.8]
 ...
 [17.  18.  19.  ...  1.1  1.4  1.3]
 [18.  19.  18.  ...  1.4  1.3  1.6]
 [19.  18.  17.  ...  1.3  1.6  1.8]]
(5652, 162)
[[30.]
 [41.]
 [44.]
 ...
 [17.]
 [24.]
 [29.]]
(5652, 1)
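The triple loop can also be written with NumPy's `sliding_window_view` (available since NumPy 1.20). The sketch below builds x and y for one month and checks them against the loop logic; `month_to_xy` is a hypothetical helper, and it assumes `month_matrix` is one float (18, 480) entry of `month_data`:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20


def month_to_xy(month_matrix):
    """Turn one (18, 480) month matrix into (471, 162) features and (471, 1) targets."""
    # windows[f, s, k] == month_matrix[f, s + k] for every 10-hour window starting at hour s
    windows = sliding_window_view(month_matrix, 10, axis=1)      # (18, 471, 10)
    # first 9 hours of each window, flattened feature-major like .reshape(1, -1) in the loop
    x = windows[:, :, :9].transpose(1, 0, 2).reshape(-1, 18 * 9)  # (471, 162)
    # PM2.5 (row 9) in the 10th hour of each window
    y = windows[9, :, 9].reshape(-1, 1)                           # (471, 1)
    return x, y


# Compare with the loop logic on one random month.
month = np.random.default_rng(1).normal(size=(18, 480))
x_fast, y_fast = month_to_xy(month)
x_loop = np.empty([471, 18 * 9])
y_loop = np.empty([471, 1])
for day in range(20):
    for hour in range(24):
        if day == 19 and hour > 14:
            continue
        s = day * 24 + hour
        x_loop[s, :] = month[:, s:s + 9].reshape(1, -1)
        y_loop[s, 0] = month[9, s + 9]
print(np.array_equal(x_fast, x_loop), np.array_equal(y_fast, y_loop))  # True True
```

Stacking the 12 months, e.g. `np.concatenate([month_to_xy(month_data[m])[0] for m in range(12)])`, would then reproduce the month-major row order of x.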


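Standardizing the features

Compute the mean and the standard deviation of each feature column of x,

then standardize every value (z-score).

np.mean(x, axis = 0): the mean of each column

np.std(x, axis = 0): the standard deviation of each column (not the variance)

standardized value = (value − mean) / standard deviation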


mean_x = np.mean(x, axis = 0)  # 18 * 9 
print(mean_x.shape)
std_x = np.std(x, axis = 0)  # 18 * 9 
print(std_x.shape)
for i in range(len(x)):  # 12 * 471
    for j in range(len(x[0])):  # 18 * 9 
        if std_x[j] != 0:
            x[i][j] = (x[i][j] - mean_x[j]) / std_x[j]
x.shape


(162,)
(162,)
(5652, 162)
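The double loop above can be replaced by NumPy broadcasting. A sketch, assuming the same rule as the loop: columns whose standard deviation is 0 are left unchanged. The helper name `standardize` is hypothetical:

```python
import numpy as np


def standardize(x, mean, std):
    """Column-wise z-score; columns with std == 0 keep their original values."""
    safe_std = np.where(std != 0, std, 1.0)  # avoid dividing by zero
    z = (x - mean) / safe_std                # broadcasts over the rows
    return np.where(std != 0, z, x)


# Compare with the loop version, including one constant column.
rng = np.random.default_rng(2)
x = rng.normal(5.0, 3.0, size=(100, 4))
x[:, 2] = 7.0  # constant column -> std is exactly 0
mean_x, std_x = x.mean(axis=0), x.std(axis=0)
x_loop = x.copy()
for i in range(len(x_loop)):
    for j in range(len(x_loop[0])):
        if std_x[j] != 0:
            x_loop[i][j] = (x_loop[i][j] - mean_x[j]) / std_x[j]
print(np.allclose(standardize(x, mean_x, std_x), x_loop))  # True
```

The same `mean_x` and `std_x` computed on the training data must be reused for the test data later.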


Building the Training and Validation Sets

Split the data by a fixed ratio: one part becomes the training set and the other the validation set (8:2 is suggested).

Note: check that x and y are split at the same point, so that their rows still line up!


import math
x_train_set = x[: math.floor(len(x) * 0.8), :]
y_train_set = y[: math.floor(len(y) * 0.8), :]
x_validation = x[math.floor(len(x) * 0.8): , :]
y_validation = y[math.floor(len(y) * 0.8): , :]
print(x_train_set)
print(y_train_set)
print(x_validation)
print(y_validation)
print(len(x_train_set))
print(len(y_train_set))
print(len(x_validation))
print(len(y_validation))


[[-1.35825331 -1.35883937 -1.359222   ...  0.26650729  0.2656797
  -1.14082131]
 [-1.35825331 -1.35883937 -1.51819928 ...  0.26650729 -1.13963133
  -1.32832904]
 [-1.35825331 -1.51789368 -1.67717656 ... -1.13923451 -1.32700613
  -0.85955971]
 ...
 [ 0.86929969  0.70886668  0.38952809 ...  1.39110073  0.2656797
  -0.39079039]
 [ 0.71018876  0.39075806  0.07157353 ...  0.26650729 -0.39013211
  -0.39079039]
 [ 0.3919669   0.07264944  0.07157353 ... -0.38950555 -0.39013211
  -0.85955971]]
[[30.]
 [41.]
 [44.]
 ...
 [ 7.]
 [ 5.]
 [14.]]
[[ 0.07374504  0.07264944  0.07157353 ... -0.38950555 -0.85856912
  -0.57829812]
 [ 0.07374504  0.07264944  0.23055081 ... -0.85808615 -0.57750692
   0.54674825]
 [ 0.07374504  0.23170375  0.23055081 ... -0.57693779  0.54674191
  -0.1095288 ]
 ...
 [-0.88092053 -0.72262212 -0.56433559 ... -0.57693779 -0.29644471
  -0.39079039]
 [-0.7218096  -0.56356781 -0.72331287 ... -0.29578943 -0.39013211
  -0.1095288 ]
 [-0.56269867 -0.72262212 -0.88229015 ... -0.38950555 -0.10906991
   0.07797893]]
[[13.]
 [24.]
 [22.]
 ...
 [17.]
 [24.]
 [29.]]
4521
4521
1131
1131
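The split can be wrapped in a small helper that also enforces the "rows line up" check from the note above. A sketch; `train_val_split` is a hypothetical name, and like the original slicing it does not shuffle. Because the rows are ordered month by month, the validation part is roughly the last 2.4 months of data (1131 / 471 ≈ 2.4):

```python
import math

import numpy as np


def train_val_split(x, y, ratio=0.8):
    """Sequential split (no shuffling): first `ratio` of the rows train, the rest validate."""
    assert len(x) == len(y), "x and y must have the same number of rows"
    cut = math.floor(len(x) * ratio)
    return x[:cut], y[:cut], x[cut:], y[cut:]


x_demo = np.zeros([12 * 471, 162])
y_demo = np.zeros([12 * 471, 1])
x_tr, y_tr, x_va, y_va = train_val_split(x_demo, y_demo)
print(len(x_tr), len(y_tr), len(x_va), len(y_va))  # 4521 4521 1131 1131
```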


Because of the constant (bias) term, the dimension (dim) needs one extra column; eps is a tiny value added so that the Adagrad denominator is never 0.


Each dimension (dim) has its own gradient and weight (w), which are learned over successive iterations (iter_time).


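The loss printed during training is the root-mean-square error (RMSE); the gradient used for the update is that of the sum of squared errors.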
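Written out, the training loop below computes the following (a summary of what the code does, with X the 5652 × 163 matrix that includes the column of 1s, y the targets, η = learning_rate and ε = eps; squares, square roots and the division in the update are element-wise):

$$
L(w) = \sum_{n=1}^{N} \left( x_n^\top w - y_n \right)^2 = \lVert Xw - y \rVert^2,
\qquad
g_t = \nabla L(w_t) = 2 X^\top (X w_t - y)
$$

$$
G_t = \sum_{\tau=1}^{t} g_\tau \odot g_\tau,
\qquad
w_{t+1} = w_t - \eta \, \frac{g_t}{\sqrt{G_t + \epsilon}}
$$

$$
\mathrm{RMSE}(w_t) = \sqrt{\frac{1}{N} \lVert X w_t - y \rVert^2},
\qquad N = 12 \times 471
$$

G_t is the `adagrad` array: each weight's step is scaled by the history of its own gradients, which is why a learning rate as large as 100 can still converge.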


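dim = 18 * 9 + 1  # 18 features × 9 hours + 1 constant (bias) term
w = np.zeros([dim, 1])  # weights, initialized to 0
x = np.concatenate((np.ones([12 * 471, 1]), x), axis = 1).astype(float)  # prepend a column of 1s for the bias
learning_rate = 100  # learning rate
iter_time = 1000  # number of iterations
adagrad = np.zeros([dim, 1])  # running sum of squared gradients, initialized to 0
eps = 0.0000000001  # keeps the Adagrad denominator away from 0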
for t in range(iter_time):
    loss = np.sqrt(np.sum(np.power(np.dot(x, w) - y, 2))/471/12)  # rmse
    if(t%100==0):  # print every 100 iterations
        print(str(t) + ":" + str(loss))
    gradient = 2 * np.dot(x.transpose(), np.dot(x, w) - y)  # dim*1
    adagrad += gradient ** 2
    w = w - learning_rate * gradient / np.sqrt(adagrad + eps)
np.save('work/weight.npy', w)  # save the weights to the path that is loaded before prediction
w


0:27.071214829194115
100:33.78905859777454
200:19.913751298197095
300:13.531068193689693
400:10.645466158446172
500:9.277353455475065
600:8.518042045956502
700:8.014061987588425
800:7.636756824775692
900:7.336563740371125
array([[ 2.13740269e+01],
       [ 3.58888909e+00],
       [ 4.56386323e+00],
       [ 2.16307023e+00],
       [-6.58545223e+00],
       [-3.38885580e+01],
       [ 3.22235518e+01],
       [ 3.49340354e+00],
       [-4.60308671e+00],
       [-1.02374754e+00],
       [-3.96791501e-01],
       [-1.06908800e-01],
       [ 2.22488184e-01],
       [ 8.99634117e-02],
       [ 1.31243105e-01],
       [ 2.15894989e-02],
       [-1.52867263e-01],
       [ 4.54087776e-02],
       [ 5.20999235e-01],
       [ 1.60824213e-01],
       [-3.17709451e-02],
       [ 1.28529025e-02],
       [-1.76839437e-01],
       [ 1.71241371e-01],
       [-1.31190032e-01],
       [-3.51614451e-02],
       [ 1.00826192e-01],
       [ 3.45018257e-01],
       [ 4.00130315e-02],
       [ 2.54331382e-02],
       [-5.04425219e-01],
       [ 3.71483018e-01],
       [ 8.46357671e-01],
       [-8.11920428e-01],
       [-8.00217575e-02],
       [ 1.52737711e-01],
       [ 2.64915130e-01],
       [-5.19860416e-02],
       [-2.51988315e-01],
       [ 3.85246517e-01],
       [ 1.65431451e-01],
       [-7.83633314e-02],
       [-2.89457231e-01],
       [ 1.77615023e-01],
       [ 3.22506948e-01],
       [-4.59955256e-01],
       [-3.48635358e-02],
       [-5.81764363e-01],
       [-6.43394528e-02],
       [-6.32876949e-01],
       [ 6.36624507e-02],
       [ 8.31592506e-02],
       [-4.45157961e-01],
       [-2.34526366e-01],
       [ 9.86608594e-01],
       [ 2.65230652e-01],
       [ 3.51938093e-02],
       [ 3.07464334e-01],
       [-1.04311239e-01],
       [-6.49166901e-02],
       [ 2.11224757e-01],
       [-2.43159815e-01],
       [-1.31285604e-01],
       [ 1.09045810e+00],
       [-3.97913710e-02],
       [ 9.19563678e-01],
       [-9.44824150e-01],
       [-5.04137735e-01],
       [ 6.81272939e-01],
       [-1.34494828e+00],
       [-2.68009542e-01],
       [ 4.36204342e-02],
       [ 1.89619513e+00],
       [-3.41873873e-01],
       [ 1.89162461e-01],
       [ 1.73251268e-02],
       [ 3.14431930e-01],
       [-3.40828467e-01],
       [ 4.92385651e-01],
       [ 9.29634214e-02],
       [-4.50983589e-01],
       [ 1.47456584e+00],
       [-3.03417236e-02],
       [ 7.71229328e-02],
       [ 6.38314494e-01],
       [-7.93287087e-01],
       [ 8.82877506e-01],
       [ 3.18965610e+00],
       [-5.75671706e+00],
       [ 1.60748945e+00],
       [ 1.36142440e+01],
       [ 1.50029111e-01],
       [-4.78389603e-02],
       [-6.29463755e-02],
       [-2.85383032e-02],
       [-3.01562821e-01],
       [ 4.12058013e-01],
       [-6.77534154e-02],
       [-1.00985479e-01],
       [-1.68972973e-01],
       [ 1.64093233e+00],
       [ 1.89670371e+00],
       [ 3.94713816e-01],
       [-4.71231449e+00],
       [-7.42760774e+00],
       [ 6.19781936e+00],
       [ 3.53986244e+00],
       [-9.56245861e-01],
       [-1.04372792e+00],
       [-4.92863713e-01],
       [ 6.31608790e-01],
       [-4.85175956e-01],
       [ 2.58400216e-01],
       [ 9.43846795e-02],
       [-1.29323184e-01],
       [-3.81235287e-01],
       [ 3.86819479e-01],
       [ 4.04211627e-01],
       [ 3.75568914e-01],
       [ 1.83512261e-01],
       [-8.01417708e-02],
       [-3.10188597e-01],
       [-3.96124612e-01],
       [ 3.66227853e-01],
       [ 1.79488593e-01],
       [-3.14477051e-01],
       [-2.37611443e-01],
       [ 3.97076104e-02],
       [ 1.38775912e-01],
       [-3.84015069e-02],
       [-5.47557119e-02],
       [ 4.19975207e-01],
       [ 4.46120687e-01],
       [-4.31074826e-01],
       [-8.74450768e-02],
       [-5.69534264e-02],
       [-7.23980157e-02],
       [-1.39880128e-02],
       [ 1.40489658e-01],
       [-2.44952334e-01],
       [ 1.83646770e-01],
       [-1.64135512e-01],
       [-7.41216452e-02],
       [-9.71414213e-02],
       [ 1.98829041e-02],
       [-4.46965919e-01],
       [-2.63440959e-01],
       [ 1.52924043e-01],
       [ 6.52532847e-02],
       [ 7.06818266e-01],
       [ 9.73757051e-02],
       [-3.35687787e-01],
       [-2.26559165e-01],
       [-3.00117086e-01],
       [ 1.24185231e-01],
       [ 4.18872344e-01],
       [-2.51891946e-01],
       [-1.29095731e-01],
       [-5.57512471e-01],
       [ 8.76239582e-02],
       [ 3.02594902e-01],
       [-4.23463160e-01],
       [ 4.89922051e-01]])
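Note that the loop above trains on all 5652 samples, so the validation split is never used. One way to use it is to wrap the same Adagrad update in a function, train on the 80% part and score the held-out 20%. A sketch with hypothetical helper names (`rmse`, `add_bias`, `train_adagrad`), checked here on synthetic data with known weights rather than on the real data:

```python
import numpy as np


def rmse(x, y, w):
    """Root-mean-square error of the predictions x @ w against y."""
    return float(np.sqrt(np.mean((x @ w - y) ** 2)))


def add_bias(x):
    """Prepend a column of 1s for the constant term (same as the concatenate call above)."""
    return np.concatenate((np.ones([len(x), 1]), x), axis=1)


def train_adagrad(x, y, learning_rate=100.0, iter_time=1000, eps=1e-10):
    """Adagrad on the sum of squared errors, same update rule as the loop above.

    `x` must already contain the leading column of 1s.
    """
    w = np.zeros([x.shape[1], 1])
    adagrad = np.zeros([x.shape[1], 1])
    for _ in range(iter_time):
        gradient = 2 * x.T @ (x @ w - y)  # gradient of sum((x @ w - y) ** 2)
        adagrad += gradient ** 2          # per-weight running sum of squared gradients
        w = w - learning_rate * gradient / np.sqrt(adagrad + eps)
    return w


# Sanity check on noise-free synthetic data: bias 4.0, weights 1.5, -2.0, 0.5.
rng = np.random.default_rng(0)
true_w = np.array([[4.0], [1.5], [-2.0], [0.5]])
x_demo = add_bias(rng.normal(size=(200, 3)))
y_demo = x_demo @ true_w
w_demo = train_adagrad(x_demo, y_demo, learning_rate=1.0, iter_time=3000)
print(np.round(w_demo.ravel(), 3), rmse(x_demo, y_demo, w_demo))
```

On the real data this would look like `w_tr = train_adagrad(add_bias(x_train_set), y_train_set)` followed by `rmse(add_bias(x_validation), y_validation, w_tr)`; `add_bias` is needed because `x_train_set` and `x_validation` were sliced before the column of 1s was added.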


Load the test data and apply the same preprocessing and feature extraction as for the training data, so that the test data become 240 samples, each of dimension 18 × 9 + 1.


# apply the same processing to the test data
testdata = pd.read_csv('work/hw1_data/test.csv', header = None, encoding = 'big5')
test_data = testdata.iloc[:, 2:].copy()  # drop the id and item columns; .copy() avoids SettingWithCopyWarning
test_data[test_data == 'NR'] = 0
test_data = test_data.to_numpy()
test_x = np.empty([240, 18*9], dtype = float)
for i in range(240):
    test_x[i, :] = test_data[18 * i: 18* (i + 1), :].reshape(1, -1)
for i in range(len(test_x)):
    for j in range(len(test_x[0])):
        if std_x[j] != 0:
            test_x[i][j] = (test_x[i][j] - mean_x[j]) / std_x[j]  # standardize with the training mean/std
test_x = np.concatenate((np.ones([240, 1]), test_x), axis = 1).astype(float)
test_x
array([[ 1.        , -0.24447681, -0.24545919, ..., -0.67065391,
        -1.04594393,  0.07797893],
       [ 1.        , -1.35825331, -1.51789368, ...,  0.17279117,
        -0.10906991, -0.48454426],
       [ 1.        ,  1.5057434 ,  1.34508393, ..., -1.32666675,
        -1.04594393, -0.57829812],
       ...,
       [ 1.        ,  0.3919669 ,  0.54981237, ...,  0.26650729,
        -0.20275731,  1.20302531],
       [ 1.        , -1.8355861 , -1.8360023 , ..., -1.04551839,
        -1.13963133, -1.14082131],
       [ 1.        , -1.35825331, -1.35883937, ...,  2.98427476,
         3.26367657,  1.76554849]])
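As with the training data, the test loops can be vectorized. Rows 18·i to 18·i + 17 belong to sample i, so a row-major reshape to (240, 162) flattens each sample feature-major, exactly like `test_data[18 * i: 18 * (i + 1), :].reshape(1, -1)`. A sketch assuming `test_data` is the (240 × 18, 9) array after dropping the first two columns and replacing 'NR', and that `mean_x`/`std_x` come from the training set; `prepare_test` is a hypothetical name:

```python
import numpy as np


def prepare_test(test_data, mean_x, std_x, n_samples=240):
    """(n_samples*18, 9) raw test rows -> (n_samples, 1 + 18*9) standardized design matrix."""
    test_x = np.asarray(test_data, dtype=float).reshape(n_samples, 18 * 9)
    safe_std = np.where(std_x != 0, std_x, 1.0)  # avoid dividing by zero
    test_x = np.where(std_x != 0, (test_x - mean_x) / safe_std, test_x)
    return np.concatenate((np.ones([n_samples, 1]), test_x), axis=1)


# Compare with the loop version on random data.
rng = np.random.default_rng(3)
raw = rng.normal(size=(240 * 18, 9))
mean_x = rng.normal(size=162)
std_x = rng.uniform(0.5, 2.0, size=162)
std_x[5] = 0.0  # one column with zero std, left unchanged
fast = prepare_test(raw, mean_x, std_x)
loop = np.empty([240, 162])
for i in range(240):
    loop[i, :] = raw[18 * i:18 * (i + 1), :].reshape(1, -1)
for i in range(len(loop)):
    for j in range(len(loop[0])):
        if std_x[j] != 0:
            loop[i][j] = (loop[i][j] - mean_x[j]) / std_x[j]
loop = np.concatenate((np.ones([240, 1]), loop), axis=1)
print(fast.shape, np.allclose(fast, loop))  # (240, 163) True
```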


With the weights and the test data, we can now predict the targets.


np.dot(test_x, w): predict!


w = np.load('work/weight.npy')  # load the saved weights
ans_y = np.dot(test_x, w)  # predict
ans_y
array([[ 5.17496040e+00],
       [ 1.83062143e+01],
       [ 2.04912181e+01],
       [ 1.15239429e+01],
       [ 2.66160568e+01],
       [ 2.05313481e+01],
       [ 2.19065510e+01],
       [ 3.17364687e+01],
       [ 1.33916741e+01],
       [ 6.44564665e+01],
       [ 2.02645688e+01],
       [ 1.53585761e+01],
       [ 6.85894728e+01],
       [ 4.84281137e+01],
       [ 1.87023338e+01],
       [ 1.01885957e+01],
       [ 3.07403629e+01],
       [ 7.11322178e+01],
       [-4.13051739e+00],
       [ 1.82356940e+01],
       [ 3.85789223e+01],
       [ 7.13115197e+01],
       [ 7.41034816e+00],
       [ 1.87179553e+01],
       [ 1.49372503e+01],
       [ 3.67197367e+01],
       [ 1.79616970e+01],
       [ 7.57894629e+01],
       [ 1.23093102e+01],
       [ 5.62953517e+01],
       [ 2.51131609e+01],
       [ 4.61024867e+00],
       [ 2.48377055e+00],
       [ 2.47594223e+01],
       [ 3.04802805e+01],
       [ 3.84639307e+01],
       [ 4.42023106e+01],
       [ 3.00868360e+01],
       [ 4.04736750e+01],
       [ 2.92264799e+01],
       [ 5.60645605e+00],
       [ 3.86660161e+01],
       [ 3.46102134e+01],
       [ 4.83896975e+01],
       [ 1.47572477e+01],
       [ 3.44668201e+01],
       [ 2.74831069e+01],
       [ 1.20008794e+01],
       [ 2.13780362e+01],
       [ 2.85444031e+01],
       [ 2.01655138e+01],
       [ 1.07966781e+01],
       [ 2.21710358e+01],
       [ 5.34462631e+01],
       [ 1.22195811e+01],
       [ 4.33009685e+01],
       [ 3.21823351e+01],
       [ 2.25672175e+01],
       [ 5.67395142e+01],
       [ 2.07450529e+01],
       [ 1.50288546e+01],
       [ 3.98553016e+01],
       [ 1.29753407e+01],
       [ 5.17416596e+01],
       [ 1.87833696e+01],
       [ 1.23487528e+01],
       [ 1.56336237e+01],
       [-5.88714707e-02],
       [ 4.15080111e+01],
       [ 3.15487475e+01],
       [ 1.86042512e+01],
       [ 3.74768197e+01],
       [ 5.65203907e+01],
       [ 6.58787719e+00],
       [ 1.22293397e+01],
       [ 5.20369640e+00],
       [ 4.79273751e+01],
       [ 1.30207057e+01],
       [ 1.71103017e+01],
       [ 2.06032345e+01],
       [ 2.12844816e+01],
       [ 3.86929353e+01],
       [ 3.00207167e+01],
       [ 8.87674067e+01],
       [ 3.59847002e+01],
       [ 2.67569136e+01],
       [ 2.39635168e+01],
       [ 3.27472428e+01],
       [ 2.21890438e+01],
       [ 2.09921589e+01],
       [ 2.95559943e+01],
       [ 4.09921689e+01],
       [ 8.62511781e+00],
       [ 3.23214718e+01],
       [ 4.65980444e+01],
       [ 2.28840708e+01],
       [ 3.15181297e+01],
       [ 1.11982335e+01],
       [ 2.85274366e+01],
       [ 2.91150680e-01],
       [ 1.79669611e+01],
       [ 2.71241639e+01],
       [ 1.13982328e+01],
       [ 1.64264269e+01],
       [ 2.34252610e+01],
       [ 4.06160827e+01],
       [ 2.58641250e+01],
       [ 5.42273695e+00],
       [ 1.07949211e+01],
       [ 7.28621369e+01],
       [ 4.80228371e+01],
       [ 1.57468083e+01],
       [ 2.46704106e+01],
       [ 1.28277933e+01],
       [ 1.01580576e+01],
       [ 2.72692233e+01],
       [ 2.92087386e+01],
       [ 8.83533962e+00],
       [ 2.00510881e+01],
       [ 2.02123337e+01],
       [ 7.99060093e+01],
       [ 1.80616143e+01],
       [ 3.05428093e+01],
       [ 2.59807924e+01],
       [ 5.21257727e+00],
       [ 3.03556973e+01],
       [ 7.76832289e+00],
       [ 1.53282683e+01],
       [ 2.26663657e+01],
       [ 6.27420542e+01],
       [ 1.89507804e+01],
       [ 1.90763556e+01],
       [ 6.13715741e+01],
       [ 1.58845621e+01],
       [ 1.34094181e+01],
       [ 8.48772484e-01],
       [ 7.83499672e+00],
       [ 5.70128290e+01],
       [ 2.56079968e+01],
       [ 4.96170473e+00],
       [ 3.64148790e+01],
       [ 2.87900067e+01],
       [ 4.91941210e+01],
       [ 4.03068699e+01],
       [ 1.33161806e+01],
       [ 2.76610119e+01],
       [ 1.71580275e+01],
       [ 4.96872626e+01],
       [ 2.30302723e+01],
       [ 3.92409365e+01],
       [ 1.31967539e+01],
       [ 5.94889370e+00],
       [ 2.58216090e+01],
       [ 8.25863421e+00],
       [ 1.91463205e+01],
       [ 4.31824865e+01],
       [ 6.71784358e+00],
       [ 3.38696152e+01],
       [ 1.53699378e+01],
       [ 1.69390450e+01],
       [ 3.78853368e+01],
       [ 1.92024845e+01],
       [ 9.05950472e+00],
       [ 1.02833996e+01],
       [ 4.86724471e+01],
       [ 3.05877162e+01],
       [ 2.47740990e+00],
       [ 1.28116039e+01],
       [ 7.03247898e+01],
       [ 1.48409677e+01],
       [ 6.88655876e+01],
       [ 4.27419924e+01],
       [ 2.40002615e+01],
       [ 2.34207249e+01],
       [ 6.16721244e+01],
       [ 2.54942028e+01],
       [ 1.90048098e+01],
       [ 3.48866829e+01],
       [ 9.40231340e+00],
       [ 2.95200113e+01],
       [ 1.45739659e+01],
       [ 9.12556314e+00],
       [ 5.28125840e+01],
       [ 4.50395380e+01],
       [ 1.74524347e+01],
       [ 3.84939353e+01],
       [ 2.70389191e+01],
       [ 6.55817097e+01],
       [ 7.03730638e+00],
       [ 5.27144771e+01],
       [ 3.82064593e+01],
       [ 2.11698011e+01],
       [ 3.02475569e+01],
       [ 2.71442299e+00],
       [ 1.99329326e+01],
       [-3.41333234e+00],
       [ 3.24459994e+01],
       [ 1.05829730e+01],
       [ 2.17752257e+01],
       [ 6.24652921e+01],
       [ 2.41329437e+01],
       [ 2.62012396e+01],
       [ 6.37444772e+01],
       [ 2.83429777e+00],
       [ 1.43792470e+01],
       [ 9.36985073e+00],
       [ 9.88116661e+00],
       [ 3.49494536e+00],
       [ 1.22608049e+02],
       [ 2.10835130e+01],
       [ 1.75322206e+01],
       [ 2.01830983e+01],
       [ 3.63931322e+01],
       [ 3.49351512e+01],
       [ 1.88303127e+01],
       [ 3.83445555e+01],
       [ 7.79166341e+01],
       [ 1.79532355e+00],
       [ 1.34458279e+01],
       [ 3.61311556e+01],
       [ 1.51504035e+01],
       [ 1.29418483e+01],
       [ 1.13125241e+02],
       [ 1.52246047e+01],
       [ 1.48240260e+01],
       [ 5.92673537e+01],
       [ 1.05836953e+01],
       [ 2.09930626e+01],
       [ 9.78936588e+00],
       [ 4.77118001e+00],
       [ 4.79278069e+01],
       [ 1.23994384e+01],
       [ 4.81464766e+01],
       [ 4.04663804e+01],
       [ 1.69405903e+01],
       [ 4.12665445e+01],
       [ 6.90278920e+01],
       [ 4.03462492e+01],
       [ 1.43137440e+01],
       [ 1.57707266e+01]])
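Three of the predictions above are negative (id_18, id_67 and id_195), although a PM2.5 concentration cannot be. One optional post-processing step, not part of the original solution, is to floor predictions at 0: since the true values are non-negative, raising a negative prediction to 0 can only reduce its error. A sketch (`clip_predictions` is a hypothetical name):

```python
import numpy as np


def clip_predictions(ans_y: np.ndarray) -> np.ndarray:
    """Optional: PM2.5 cannot be negative, so floor the predictions at 0."""
    return np.clip(ans_y, 0, None)


# id_0 and the three negative predictions from the output above
sample = np.array([[5.17496040], [-4.13051739], [-0.0588714707], [-3.41333234]])
print(clip_predictions(sample).ravel())
```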


Save the predictions


import csv
with open('work/submit.csv', mode='w', newline='') as submit_file:
    csv_writer = csv.writer(submit_file)
    header = ['id', 'value']
    print(header)
    csv_writer.writerow(header)
    for i in range(240):
        row = ['id_' + str(i), ans_y[i][0]]
        csv_writer.writerow(row)
        print(row)


['id', 'value']
['id_0', 5.174960398984736]
['id_1', 18.30621425352788]
['id_2', 20.491218094180528]
['id_3', 11.523942869805332]
['id_4', 26.616056752306132]
['id_5', 20.531348081761223]
['id_6', 21.906551018797376]
['id_7', 31.736468747068834]
['id_8', 13.391674055111736]
['id_9', 64.45646650291954]
['id_10', 20.264568836159437]
['id_11', 15.35857607736122]
['id_12', 68.58947276926726]
['id_13', 48.428113747457196]
['id_14', 18.702333824193218]
['id_15', 10.188595737466716]
['id_16', 30.74036285982044]
['id_17', 71.13221776355108]
['id_18', -4.130517391262444]
['id_19', 18.235694016428695]
['id_20', 38.578922275007756]
['id_21', 71.31151972531332]
['id_22', 7.410348162634086]
['id_23', 18.717955330321395]
['id_24', 14.93725026008458]
['id_25', 36.719736694705325]
['id_26', 17.96169700566271]
['id_27', 75.78946287210539]
['id_28', 12.309310248614484]
['id_29', 56.2953517396496]
['id_30', 25.113160865661484]
['id_31', 4.610248674094053]
['id_32', 2.4837705545150315]
['id_33', 24.75942226132128]
['id_34', 30.480280465591196]
['id_35', 38.46393074642666]
['id_36', 44.20231060933005]
['id_37', 30.08683601986601]
['id_38', 40.47367501574008]
['id_39', 29.22647990231738]
['id_40', 5.606456054343949]
['id_41', 38.666016078789596]
['id_42', 34.61021343187721]
['id_43', 48.38969750738482]
['id_44', 14.757247666944172]
['id_45', 34.46682011087208]
['id_46', 27.48310687418436]
['id_47', 12.000879378154043]
['id_48', 21.378036151603794]
['id_49', 28.54440309166328]
['id_50', 20.16551381841159]
['id_51', 10.796678149746501]
['id_52', 22.171035755750125]
['id_53', 53.446263109352266]
['id_54', 12.21958112161002]
['id_55', 43.30096845517155]
['id_56', 32.1823351032854]
['id_57', 22.5672175145708]
['id_58', 56.73951416554704]
['id_59', 20.745052945295473]
['id_60', 15.028854557473265]
['id_61', 39.8553015903851]
['id_62', 12.975340680728284]
['id_63', 51.74165959283004]
['id_64', 18.783369632539877]
['id_65', 12.348752842777712]
['id_66', 15.633623653541925]
['id_67', -0.05887147068500154]
['id_68', 41.50801107307596]
['id_69', 31.548747530656026]
['id_70', 18.604251157547075]
['id_71', 37.4768197248807]
['id_72', 56.52039065762305]
['id_73', 6.58787719352195]
['id_74', 12.229339737435051]
['id_75', 5.203696404134638]
['id_76', 47.92737510380059]
['id_77', 13.020705685594661]
['id_78', 17.110301693903597]
['id_79', 20.603234531002048]
['id_80', 21.284481560784613]
['id_81', 38.69293529051181]
['id_82', 30.020716675725847]
['id_83', 88.76740666723548]
['id_84', 35.984700239668264]
['id_85', 26.756913553477187]
['id_86', 23.963516843564403]
['id_87', 32.747242828083074]
['id_88', 22.18904375531994]
['id_89', 20.992158853626545]
['id_90', 29.555994316645446]
['id_91', 40.99216886651781]
['id_92', 8.625117809911558]
['id_93', 32.3214718088779]
['id_94', 46.59804436536759]
['id_95', 22.88407082672354]
['id_96', 31.518129728251655]
['id_97', 11.19823347976612]
['id_98', 28.527436642529608]
['id_99', 0.2911506800896443]
['id_100', 17.96696107953969]
['id_101', 27.124163929470143]
['id_102', 11.398232780652847]
['id_103', 16.426426865673527]
['id_104', 23.42526104692219]
['id_105', 40.6160826705684]
['id_106', 25.8641250265604]
['id_107', 5.422736951672389]
['id_108', 10.794921122256104]
['id_109', 72.86213692992126]
['id_110', 48.022837059481375]
['id_111', 15.746808276902996]
['id_112', 24.67041061417795]
['id_113', 12.827793326536716]
['id_114', 10.158057570240526]
['id_115', 27.269223342020982]
['id_116', 29.208738577932458]
['id_117', 8.835339619930767]
['id_118', 20.05108813712978]
['id_119', 20.212333743764248]
['id_120', 79.9060092987056]
['id_121', 18.061614288263595]
['id_122', 30.542809341304345]
['id_123', 25.98079237772804]
['id_124', 5.212577268164767]
['id_125', 30.355697305856214]
['id_126', 7.768322888914637]
['id_127', 15.328268255393336]
['id_128', 22.66636571769797]
['id_129', 62.742054211090085]
['id_130', 18.950780367987996]
['id_131', 19.076355630838545]
['id_132', 61.37157409163711]
['id_133', 15.884562052629718]
['id_134', 13.409418077705558]
['id_135', 0.8487724836112842]
['id_136', 7.834996717304126]
['id_137', 57.01282901179679]
['id_138', 25.607996751813804]
['id_139', 4.9617047292420855]
['id_140', 36.414879039062775]
['id_141', 28.790006721975917]
['id_142', 49.19412096197634]
['id_143', 40.3068698557345]
['id_144', 13.316180593982658]
['id_145', 27.661011875229164]
['id_146', 17.158027524366766]
['id_147', 49.68726256929682]
['id_148', 23.03027229160478]
['id_149', 39.240936524842766]
['id_150', 13.19675388941254]
['id_151', 5.948893701039413]
['id_152', 25.82160897630425]
['id_153', 8.258634214291634]
['id_154', 19.146320517225597]
['id_155', 43.18248652651674]
['id_156', 6.717843578093033]
['id_157', 33.869615246810646]
['id_158', 15.3699378469818]
['id_159', 16.939044973551923]
['id_160', 37.88533679463485]
['id_161', 19.202484541054467]
['id_162', 9.059504715654725]
['id_163', 10.283399610648509]
['id_164', 48.672447125698284]
['id_165', 30.58771621323082]
['id_166', 2.4774098975321657]
['id_167', 12.811603937805932]
['id_168', 70.32478980976464]
['id_169', 14.840967694067068]
['id_170', 68.8655875667886]
['id_171', 42.74199244486634]
['id_172', 24.000261542920168]
['id_173', 23.420724860321446]
['id_174', 61.672124435682356]
['id_175', 25.494202845059192]
['id_176', 19.004809786869096]
['id_177', 34.88668288189683]
['id_178', 9.40231339837975]
['id_179', 29.520011314408027]
['id_180', 14.573965885700483]
['id_181', 9.125563143203598]
['id_182', 52.81258399813187]
['id_183', 45.03953799438962]
['id_184', 17.452434679183295]
['id_185', 38.49393527971433]
['id_186', 27.03891909264382]
['id_187', 65.58170967424583]
['id_188', 7.0373063807695795]
['id_189', 52.71447713411572]
['id_190', 38.20645933704977]
['id_191', 21.16980105955784]
['id_192', 30.247556879488393]
['id_193', 2.714422989716304]
['id_194', 19.93293258764082]
['id_195', -3.413332337603944]
['id_196', 32.44599940281316]
['id_197', 10.582973029979941]
['id_198', 21.77522570725845]
['id_199', 62.465292065677886]
['id_200', 24.13294368731649]
['id_201', 26.201239647400964]
['id_202', 63.74447723440287]
['id_203', 2.83429777412905]
['id_204', 14.37924698697884]
['id_205', 9.369850731753894]
['id_206', 9.881166613595411]
['id_207', 3.4949453589721426]
['id_208', 122.6080493792178]
['id_209', 21.083513014480573]
['id_210', 17.53222059945511]
['id_211', 20.183098344597003]
['id_212', 36.39313221228185]
['id_213', 34.93515120529068]
['id_214', 18.83031266145864]
['id_215', 38.34455552272332]
['id_216', 77.91663413807038]
['id_217', 1.7953235508882215]
['id_218', 13.445827939135775]
['id_219', 36.131155590412135]
['id_220', 15.150403498166307]
['id_221', 12.941848334417926]
['id_222', 113.12524093786391]
['id_223', 15.224604677934382]
['id_224', 14.824025968612034]
['id_225', 59.267353688540446]
['id_226', 10.583695290718481]
['id_227', 20.993062563532174]
['id_228', 9.789365880830381]
['id_229', 4.77118000870597]
['id_230', 47.92780690481291]
['id_231', 12.399438394751039]
['id_232', 48.14647656264414]
['id_233', 40.46638039656415]
['id_234', 16.94059027033294]
['id_235', 41.26654448941875]
['id_236', 69.02789203372899]
['id_237', 40.34624924412241]
['id_238', 14.313743982871129]
['id_239', 15.770726634219828]
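The saving cell can be turned into a reusable function and checked by reading the file back. A sketch; `write_submission` is a hypothetical name, it writes the same `id,value` layout as the loop above, and the check uses dummy predictions in a temporary directory instead of `work/submit.csv`:

```python
import csv
import os
import tempfile

import numpy as np


def write_submission(path, ans_y):
    """Write predictions as an `id,value` CSV, one row per test sample."""
    with open(path, mode="w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "value"])
        for i, value in enumerate(np.asarray(ans_y).ravel()):
            writer.writerow([f"id_{i}", float(value)])


# Round-trip check on dummy predictions.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "submit.csv")
    write_submission(path, np.arange(240, dtype=float).reshape(-1, 1))
    with open(path, newline="") as f:
        rows = list(csv.reader(f))
print(len(rows), rows[0], rows[1])  # 241 ['id', 'value'] ['id_0', '0.0']
```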


The legendary "worst coder in the PaddlePaddle (飞桨) community": let's keep improving together!

Remember: anything made by 三岁 is a gem (the shameless series)

