本项目仅用于参考,提供思路和想法并非标准答案!请谨慎抄袭!
作业1-PM2.5预测
项目描述
- 本次作业的资料是从行政院环境环保署空气品质监测网所下载的观测资料。
- 希望大家能在本作业实现 linear regression 预测出 PM2.5 的数值。
数据集介绍
- 本次作业使用丰原站的观测记录,分成 train set 跟 test set,train set 是丰原站每个月的前 20 天所有资料。test set 则是从丰原站剩下的资料中取样出来。
- train.csv: 每个月前 20 天的完整资料。
- test.csv : 从剩下的资料当中取样出连续的 10 小时为一笔,前九小时的所有观测数据当作 feature,第十小时的 PM2.5 当作 answer。一共取出 240 笔不重複的 test data,请根据 feature 预测这 240 笔的 PM2.5。
- Data 含有 18 项观测数据 AMB_TEMP, CH4, CO, NHMC, NO, NO2, NOx, O3, PM10, PM2.5, RAINFALL, RH, SO2, THC, WD_HR, WIND_DIREC, WIND_SPEED, WS_HR。
项目要求
- 请手动实现 linear regression,方法限使用 gradient descent。
- 禁止使用 numpy.linalg.lstsq
数据准备
无
环境配置/安装
数据解读
对数据进行理解和了解后数据如图:
横向分别是24小时的数据值
竖向是12个月、每月20天、每天18种数据
!pip install --upgrade pandas
Looking in indexes: https://mirror.baidu.com/pypi/simple/ Requirement already up-to-date: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (1.2.3) Requirement already satisfied, skipping upgrade: pytz>=2017.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pandas) (2019.3) Requirement already satisfied, skipping upgrade: python-dateutil>=2.7.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pandas) (2.8.0) Requirement already satisfied, skipping upgrade: numpy>=1.16.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pandas) (1.20.1) Requirement already satisfied, skipping upgrade: six>=1.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from python-dateutil>=2.7.3->pandas) (1.15.0)
导入需要的包,读取训练集
import numpy as np import pandas as pd data = pd.read_csv('work/hw1_data/train.csv', encoding = 'big5') # 使用'big5'进行编码
print(data) # 查看数据 print(data.shape) # 查看数据大小
0 1 2 3 4 5 6 7 8 9 ... 14 \ 0 14 14 14 13 12 12 12 12 15 17 ... 22 1 1.8 1.8 1.8 1.8 1.8 1.8 1.8 1.8 1.8 1.8 ... 1.8 2 0.51 0.41 0.39 0.37 0.35 0.3 0.37 0.47 0.78 0.74 ... 0.37 3 0.2 0.15 0.13 0.12 0.11 0.06 0.1 0.13 0.26 0.23 ... 0.1 4 0.9 0.6 0.5 1.7 1.8 1.5 1.9 2.2 6.6 7.9 ... 2.5 ... ... ... ... ... ... ... ... ... ... ... ... ... 4315 1.8 1.8 1.8 1.8 1.8 1.7 1.7 1.8 1.8 1.8 ... 1.8 4316 46 13 61 44 55 68 66 70 66 85 ... 59 4317 36 55 72 327 74 52 59 83 106 105 ... 18 4318 1.9 2.4 1.9 2.8 2.3 1.9 2.1 3.7 2.8 3.8 ... 2.3 4319 0.7 0.8 1.8 1 1.9 1.7 2.1 2 2 1.7 ... 1.3 15 16 17 18 19 20 21 22 23 0 22 21 19 17 16 15 15 15 15 1 1.8 1.8 1.8 1.8 1.8 1.8 1.8 1.8 1.8 2 0.37 0.47 0.69 0.56 0.45 0.38 0.35 0.36 0.32 3 0.13 0.14 0.23 0.18 0.12 0.1 0.09 0.1 0.08 4 2.2 2.5 2.3 2.1 1.9 1.5 1.6 1.8 1.5 ... ... ... ... ... ... ... ... ... ... 4315 1.8 2 2.1 2 1.9 1.9 1.9 2 2 4316 308 327 21 100 109 108 114 108 109 4317 311 52 54 121 97 107 118 100 105 4318 2.6 1.3 1 1.5 1 1.7 1.5 2 2 4319 1.7 0.7 0.4 1.1 1.4 1.3 1.6 1.8 2 [4320 rows x 24 columns] (4320, 24)
取需要的数值部分,将 ‘RAINFALL’ 栏位全部补 0。
对数据进行查看后发现有缺失的数据,对缺失数据进行处理,填补缺失值。
读取的数据中有部分解释性的内容,我们不需要,可以进行提取直接忽略
data.iloc[:,:]
该函数用于处理数据,把我们需要的部分进行切割获取
data[data == 'xxx'] = 0
把xxx的内容替换成0
data = data.iloc[:, 3:] # 从列表的第4路项开始取(不要那些没有意义的数字) print(data) #查看数据 print(data.shape) #查看数据大小 print(type(data)) data[data == 'NR'] = 0 # 把'NR'项装换成0 raw_data = data.to_numpy() # 把数据转换成numpy数组
0 1 2 3 4 5 6 7 8 9 ... 14 \ 0 14 14 14 13 12 12 12 12 15 17 ... 22 1 1.8 1.8 1.8 1.8 1.8 1.8 1.8 1.8 1.8 1.8 ... 1.8 2 0.51 0.41 0.39 0.37 0.35 0.3 0.37 0.47 0.78 0.74 ... 0.37 3 0.2 0.15 0.13 0.12 0.11 0.06 0.1 0.13 0.26 0.23 ... 0.1 4 0.9 0.6 0.5 1.7 1.8 1.5 1.9 2.2 6.6 7.9 ... 2.5 ... ... ... ... ... ... ... ... ... ... ... ... ... 4315 1.8 1.8 1.8 1.8 1.8 1.7 1.7 1.8 1.8 1.8 ... 1.8 4316 46 13 61 44 55 68 66 70 66 85 ... 59 4317 36 55 72 327 74 52 59 83 106 105 ... 18 4318 1.9 2.4 1.9 2.8 2.3 1.9 2.1 3.7 2.8 3.8 ... 2.3 4319 0.7 0.8 1.8 1 1.9 1.7 2.1 2 2 1.7 ... 1.3 15 16 17 18 19 20 21 22 23 0 22 21 19 17 16 15 15 15 15 1 1.8 1.8 1.8 1.8 1.8 1.8 1.8 1.8 1.8 2 0.37 0.47 0.69 0.56 0.45 0.38 0.35 0.36 0.32 3 0.13 0.14 0.23 0.18 0.12 0.1 0.09 0.1 0.08 4 2.2 2.5 2.3 2.1 1.9 1.5 1.6 1.8 1.5 ... ... ... ... ... ... ... ... ... ... 4315 1.8 2 2.1 2 1.9 1.9 1.9 2 2 4316 308 327 21 100 109 108 114 108 109 4317 311 52 54 121 97 107 118 100 105 4318 2.6 1.3 1 1.5 1 1.7 1.5 2 2 4319 1.7 0.7 0.4 1.1 1.4 1.3 1.6 1.8 2 [4320 rows x 24 columns] (4320, 24) <class 'pandas.core.frame.DataFrame'>
print(raw_data.shape) # 查看数组大小 print(type(raw_data)) # 查看类型
(4320, 24) <class 'numpy.ndarray'>
从原先的24*(18*20*12)
转换成12*18*(20*24)
month_data = {} for month in range(12): sample = np.empty([18, 480]) # 新建np数组大小是[18, 480]内容随机 for day in range(20): sample[:, day * 24 : (day + 1) * 24] = raw_data[18 * (20 * month + day) : 18 * (20 * month + day + 1), :] month_data[month] = sample
# print(len(month_data),len(month_data[0])) # 大小查看 # print(month_data) # 数据查看 print(month_data[month]) print(month_data[month].shape)
[[ 23. 23. 23. ... 13. 13. 13. ] [ 1.6 1.7 1.7 ... 1.8 1.8 1.8 ] [ 0.22 0.2 0.18 ... 0.51 0.57 0.56] ... [ 93. 50. 99. ... 118. 100. 105. ] [ 1.8 2.1 3.2 ... 1.5 2. 2. ] [ 1.3 0.9 1. ... 1.6 1.8 2. ]] (18, 480)
每个月会有 480hrs,每 9 小时形成一个 data,每个月会有 471 个 data,故总资料数为 471 * 12 笔,而每笔 data 有 9 * 18 的 features (一小时 18 个 features * 9 小时)。
- 471次/月 * 12个月就是我们得到的数据量,而每次的数据量是9小时 * 18种数据
对应的 target 则有 471 * 12 个(第 10 个小时的 PM2.5)
- target是下一个时刻的PM2.5的值
解析:
一个月的数据就是20天 * 24小时 共计 480小时的数据
按照9小时一组进行处理应该得到 480-9+1== 472
但是最后一组数据是没有最后对应的y的值的
所以:是472 - 1 == 471
- 为什么是9次一组不是10次???
题目要求:前九小时的所有观测数据当作 feature,第十小时的 PM2.5 当作 answer。
根据要求9次数据为训练集第10次的为比较值。 - 为什么不10个一组然后再做数据处理???
这个建议不是不行,但是如果10个一组就少了一点点的训练量
这个具体看自己理解。
x = np.empty([12 * 471, 18 * 9], dtype = float) y = np.empty([12 * 471, 1], dtype = float) for month in range(12): for day in range(20): for hour in range(24): if day == 19 and hour > 14: continue x[month * 471 + day * 24 + hour, :] = month_data[month][:,day * 24 + hour : day * 24 + hour + 9].reshape(1, -1) #vector dim:18*9 (9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9) y[month * 471 + day * 24 + hour, 0] = month_data[month][9, day * 24 + hour + 9] #value print(x) print(x.shape) print(y) print(y.shape)
[[14. 14. 14. ... 2. 2. 0.5] [14. 14. 13. ... 2. 0.5 0.3] [14. 13. 12. ... 0.5 0.3 0.8] ... [17. 18. 19. ... 1.1 1.4 1.3] [18. 19. 18. ... 1.4 1.3 1.6] [19. 18. 17. ... 1.3 1.6 1.8]] (5652, 162) [[30.] [41.] [44.] ... [17.] [24.] [29.]] (5652, 1)
求值的离散程度
计算x值的方差和平均值
然后计算每一个值的离散程度
np.mean(x, axis = 0)
:平均值
np.std(x, axis = 0)
: 方差
离散程度 = 平均差/方差
mean_x = np.mean(x, axis = 0) # 18 * 9 print(mean_x.shape) std_x = np.std(x, axis = 0) # 18 * 9 print(std_x.shape) for i in range(len(x)): # 12 * 471 for j in range(len(x[0])): # 18 * 9 if std_x[j] != 0: x[i][j] = (x[i][j] - mean_x[j]) / std_x[j] x.shape
(162,) (162,) (5652, 162)
数据集生成
把数据集安装一定的比例进行区分,一部分生成训练集一部分生成测试集(建议8:2)
注:查看对应的数据是否一样!
import math x_train_set = x[: math.floor(len(x) * 0.8), :] y_train_set = y[: math.floor(len(y) * 0.8), :] x_validation = x[math.floor(len(x) * 0.8): , :] y_validation = y[math.floor(len(y) * 0.8): , :] print(x_train_set) print(y_train_set) print(x_validation) print(y_validation) print(len(x_train_set)) print(len(y_train_set)) print(len(x_validation)) print(len(y_validation))
[[-1.35825331 -1.35883937 -1.359222 ... 0.26650729 0.2656797 -1.14082131] [-1.35825331 -1.35883937 -1.51819928 ... 0.26650729 -1.13963133 -1.32832904] [-1.35825331 -1.51789368 -1.67717656 ... -1.13923451 -1.32700613 -0.85955971] ... [ 0.86929969 0.70886668 0.38952809 ... 1.39110073 0.2656797 -0.39079039] [ 0.71018876 0.39075806 0.07157353 ... 0.26650729 -0.39013211 -0.39079039] [ 0.3919669 0.07264944 0.07157353 ... -0.38950555 -0.39013211 -0.85955971]] [[30.] [41.] [44.] ... [ 7.] [ 5.] [14.]] [[ 0.07374504 0.07264944 0.07157353 ... -0.38950555 -0.85856912 -0.57829812] [ 0.07374504 0.07264944 0.23055081 ... -0.85808615 -0.57750692 0.54674825] [ 0.07374504 0.23170375 0.23055081 ... -0.57693779 0.54674191 -0.1095288 ] ... [-0.88092053 -0.72262212 -0.56433559 ... -0.57693779 -0.29644471 -0.39079039] [-0.7218096 -0.56356781 -0.72331287 ... -0.29578943 -0.39013211 -0.1095288 ] [-0.56269867 -0.72262212 -0.88229015 ... -0.38950555 -0.10906991 0.07797893]] [[13.] [24.] [22.] ... [17.] [24.] [29.]] 4521 4521 1131 1131
因为常数项的存在,所以 dimension (dim) 需要多加一栏;eps 项是避免 adagrad 的分母为 0 而加的极小数值。
每一个 dimension (dim) 会对应到各自的 gradient, weight (w),透过一次次的 iteration (iter_time) 学习。
采用均方根误差
dim = 18 * 9 + 1 # 18个数据*9次+1(常量)个 w = np.zeros([dim, 1]) # 生成数据是0的数组 x = np.concatenate((np.ones([12 * 471, 1]), x), axis = 1).astype(float) # 拼接1和x数组 learning_rate = 100 # 学习率 iter_time = 1000 # 学习次数 adagrad = np.zeros([dim, 1]) # 生成数据是0的数组 eps = 0.0000000001 for t in range(iter_time): loss = np.sqrt(np.sum(np.power(np.dot(x, w) - y, 2))/471/12) # rmse if(t%100==0): # 100轮输出 print(str(t) + ":" + str(loss)) gradient = 2 * np.dot(x.transpose(), np.dot(x, w) - y) # dim*1 adagrad += gradient ** 2 w = w - learning_rate * gradient / np.sqrt(adagrad + eps) np.save('weight.npy', w) # 保存文件 w
0:27.071214829194115 100:33.78905859777454 200:19.913751298197095 300:13.531068193689693 400:10.645466158446172 500:9.277353455475065 600:8.518042045956502 700:8.014061987588425 800:7.636756824775692 900:7.336563740371125 array([[ 2.13740269e+01], [ 3.58888909e+00], [ 4.56386323e+00], [ 2.16307023e+00], [-6.58545223e+00], [-3.38885580e+01], [ 3.22235518e+01], [ 3.49340354e+00], [-4.60308671e+00], [-1.02374754e+00], [-3.96791501e-01], [-1.06908800e-01], [ 2.22488184e-01], [ 8.99634117e-02], [ 1.31243105e-01], [ 2.15894989e-02], [-1.52867263e-01], [ 4.54087776e-02], [ 5.20999235e-01], [ 1.60824213e-01], [-3.17709451e-02], [ 1.28529025e-02], [-1.76839437e-01], [ 1.71241371e-01], [-1.31190032e-01], [-3.51614451e-02], [ 1.00826192e-01], [ 3.45018257e-01], [ 4.00130315e-02], [ 2.54331382e-02], [-5.04425219e-01], [ 3.71483018e-01], [ 8.46357671e-01], [-8.11920428e-01], [-8.00217575e-02], [ 1.52737711e-01], [ 2.64915130e-01], [-5.19860416e-02], [-2.51988315e-01], [ 3.85246517e-01], [ 1.65431451e-01], [-7.83633314e-02], [-2.89457231e-01], [ 1.77615023e-01], [ 3.22506948e-01], [-4.59955256e-01], [-3.48635358e-02], [-5.81764363e-01], [-6.43394528e-02], [-6.32876949e-01], [ 6.36624507e-02], [ 8.31592506e-02], [-4.45157961e-01], [-2.34526366e-01], [ 9.86608594e-01], [ 2.65230652e-01], [ 3.51938093e-02], [ 3.07464334e-01], [-1.04311239e-01], [-6.49166901e-02], [ 2.11224757e-01], [-2.43159815e-01], [-1.31285604e-01], [ 1.09045810e+00], [-3.97913710e-02], [ 9.19563678e-01], [-9.44824150e-01], [-5.04137735e-01], [ 6.81272939e-01], [-1.34494828e+00], [-2.68009542e-01], [ 4.36204342e-02], [ 1.89619513e+00], [-3.41873873e-01], [ 1.89162461e-01], [ 1.73251268e-02], [ 3.14431930e-01], [-3.40828467e-01], [ 4.92385651e-01], [ 9.29634214e-02], [-4.50983589e-01], [ 1.47456584e+00], [-3.03417236e-02], [ 7.71229328e-02], [ 6.38314494e-01], [-7.93287087e-01], [ 8.82877506e-01], [ 3.18965610e+00], [-5.75671706e+00], [ 1.60748945e+00], [ 1.36142440e+01], [ 1.50029111e-01], [-4.78389603e-02], [-6.29463755e-02], [-2.85383032e-02], [-3.01562821e-01], [ 4.12058013e-01], [-6.77534154e-02], [-1.00985479e-01], [-1.68972973e-01], [ 1.64093233e+00], [ 1.89670371e+00], [ 3.94713816e-01], [-4.71231449e+00], [-7.42760774e+00], [ 6.19781936e+00], [ 3.53986244e+00], [-9.56245861e-01], [-1.04372792e+00], [-4.92863713e-01], [ 6.31608790e-01], [-4.85175956e-01], [ 2.58400216e-01], [ 9.43846795e-02], [-1.29323184e-01], [-3.81235287e-01], [ 3.86819479e-01], [ 4.04211627e-01], [ 3.75568914e-01], [ 1.83512261e-01], [-8.01417708e-02], [-3.10188597e-01], [-3.96124612e-01], [ 3.66227853e-01], [ 1.79488593e-01], [-3.14477051e-01], [-2.37611443e-01], [ 3.97076104e-02], [ 1.38775912e-01], [-3.84015069e-02], [-5.47557119e-02], [ 4.19975207e-01], [ 4.46120687e-01], [-4.31074826e-01], [-8.74450768e-02], [-5.69534264e-02], [-7.23980157e-02], [-1.39880128e-02], [ 1.40489658e-01], [-2.44952334e-01], [ 1.83646770e-01], [-1.64135512e-01], [-7.41216452e-02], [-9.71414213e-02], [ 1.98829041e-02], [-4.46965919e-01], [-2.63440959e-01], [ 1.52924043e-01], [ 6.52532847e-02], [ 7.06818266e-01], [ 9.73757051e-02], [-3.35687787e-01], [-2.26559165e-01], [-3.00117086e-01], [ 1.24185231e-01], [ 4.18872344e-01], [-2.51891946e-01], [-1.29095731e-01], [-5.57512471e-01], [ 8.76239582e-02], [ 3.02594902e-01], [-4.23463160e-01], [ 4.89922051e-01]])
加载 test data,并且以相似于训练资料预先处理和特徵萃取的方式处理,使 test data 形成 240 个维度为 18 * 9 + 1 的资料。
# 对tast_data做同等处理 testdata = pd.read_csv('work/hw1_data/test.csv', header = None, encoding = 'big5') test_data = testdata.iloc[:, 2:] test_data[test_data == 'NR'] = 0 test_data = test_data.to_numpy() test_x = np.empty([240, 18*9], dtype = float) for i in range(240): test_x[i, :] = test_data[18 * i: 18* (i + 1), :].reshape(1, -1) for i in range(len(test_x)): for j in range(len(test_x[0])): if std_x[j] != 0: test_x[i][j] = (test_x[i][j] - mean_x[j]) / std_x[j] test_x = np.concatenate((np.ones([240, 1]), test_x), axis = 1).astype(float) test_x
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/ipykernel_launcher.py:3: SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame. Try using .loc[row_indexer,col_indexer] = value instead See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy This is separate from the ipykernel package so we can avoid doing imports until /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/pandas/core/frame.py:3215: SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy self._where(-key, value, inplace=True) array([[ 1. , -0.24447681, -0.24545919, ..., -0.67065391, -1.04594393, 0.07797893], [ 1. , -1.35825331, -1.51789368, ..., 0.17279117, -0.10906991, -0.48454426], [ 1. , 1.5057434 , 1.34508393, ..., -1.32666675, -1.04594393, -0.57829812], ..., [ 1. , 0.3919669 , 0.54981237, ..., 0.26650729, -0.20275731, 1.20302531], [ 1. , -1.8355861 , -1.8360023 , ..., -1.04551839, -1.13963133, -1.14082131], [ 1. , -1.35825331, -1.35883937, ..., 2.98427476, 3.26367657, 1.76554849]])
有了 weight 和测试资料就可以预测 target。
np.dot(test_x, w)
:预测!
w = np.load('work/weight.npy') # 读取文档 ans_y = np.dot(test_x, w) # 预测 ans_y
array([[ 5.17496040e+00], [ 1.83062143e+01], [ 2.04912181e+01], [ 1.15239429e+01], [ 2.66160568e+01], [ 2.05313481e+01], [ 2.19065510e+01], [ 3.17364687e+01], [ 1.33916741e+01], [ 6.44564665e+01], [ 2.02645688e+01], [ 1.53585761e+01], [ 6.85894728e+01], [ 4.84281137e+01], [ 1.87023338e+01], [ 1.01885957e+01], [ 3.07403629e+01], [ 7.11322178e+01], [-4.13051739e+00], [ 1.82356940e+01], [ 3.85789223e+01], [ 7.13115197e+01], [ 7.41034816e+00], [ 1.87179553e+01], [ 1.49372503e+01], [ 3.67197367e+01], [ 1.79616970e+01], [ 7.57894629e+01], [ 1.23093102e+01], [ 5.62953517e+01], [ 2.51131609e+01], [ 4.61024867e+00], [ 2.48377055e+00], [ 2.47594223e+01], [ 3.04802805e+01], [ 3.84639307e+01], [ 4.42023106e+01], [ 3.00868360e+01], [ 4.04736750e+01], [ 2.92264799e+01], [ 5.60645605e+00], [ 3.86660161e+01], [ 3.46102134e+01], [ 4.83896975e+01], [ 1.47572477e+01], [ 3.44668201e+01], [ 2.74831069e+01], [ 1.20008794e+01], [ 2.13780362e+01], [ 2.85444031e+01], [ 2.01655138e+01], [ 1.07966781e+01], [ 2.21710358e+01], [ 5.34462631e+01], [ 1.22195811e+01], [ 4.33009685e+01], [ 3.21823351e+01], [ 2.25672175e+01], [ 5.67395142e+01], [ 2.07450529e+01], [ 1.50288546e+01], [ 3.98553016e+01], [ 1.29753407e+01], [ 5.17416596e+01], [ 1.87833696e+01], [ 1.23487528e+01], [ 1.56336237e+01], [-5.88714707e-02], [ 4.15080111e+01], [ 3.15487475e+01], [ 1.86042512e+01], [ 3.74768197e+01], [ 5.65203907e+01], [ 6.58787719e+00], [ 1.22293397e+01], [ 5.20369640e+00], [ 4.79273751e+01], [ 1.30207057e+01], [ 1.71103017e+01], [ 2.06032345e+01], [ 2.12844816e+01], [ 3.86929353e+01], [ 3.00207167e+01], [ 8.87674067e+01], [ 3.59847002e+01], [ 2.67569136e+01], [ 2.39635168e+01], [ 3.27472428e+01], [ 2.21890438e+01], [ 2.09921589e+01], [ 2.95559943e+01], [ 4.09921689e+01], [ 8.62511781e+00], [ 3.23214718e+01], [ 4.65980444e+01], [ 2.28840708e+01], [ 3.15181297e+01], [ 1.11982335e+01], [ 2.85274366e+01], [ 2.91150680e-01], [ 1.79669611e+01], [ 2.71241639e+01], [ 1.13982328e+01], [ 1.64264269e+01], [ 2.34252610e+01], [ 4.06160827e+01], [ 2.58641250e+01], [ 5.42273695e+00], [ 1.07949211e+01], [ 7.28621369e+01], [ 4.80228371e+01], [ 1.57468083e+01], [ 2.46704106e+01], [ 1.28277933e+01], [ 1.01580576e+01], [ 2.72692233e+01], [ 2.92087386e+01], [ 8.83533962e+00], [ 2.00510881e+01], [ 2.02123337e+01], [ 7.99060093e+01], [ 1.80616143e+01], [ 3.05428093e+01], [ 2.59807924e+01], [ 5.21257727e+00], [ 3.03556973e+01], [ 7.76832289e+00], [ 1.53282683e+01], [ 2.26663657e+01], [ 6.27420542e+01], [ 1.89507804e+01], [ 1.90763556e+01], [ 6.13715741e+01], [ 1.58845621e+01], [ 1.34094181e+01], [ 8.48772484e-01], [ 7.83499672e+00], [ 5.70128290e+01], [ 2.56079968e+01], [ 4.96170473e+00], [ 3.64148790e+01], [ 2.87900067e+01], [ 4.91941210e+01], [ 4.03068699e+01], [ 1.33161806e+01], [ 2.76610119e+01], [ 1.71580275e+01], [ 4.96872626e+01], [ 2.30302723e+01], [ 3.92409365e+01], [ 1.31967539e+01], [ 5.94889370e+00], [ 2.58216090e+01], [ 8.25863421e+00], [ 1.91463205e+01], [ 4.31824865e+01], [ 6.71784358e+00], [ 3.38696152e+01], [ 1.53699378e+01], [ 1.69390450e+01], [ 3.78853368e+01], [ 1.92024845e+01], [ 9.05950472e+00], [ 1.02833996e+01], [ 4.86724471e+01], [ 3.05877162e+01], [ 2.47740990e+00], [ 1.28116039e+01], [ 7.03247898e+01], [ 1.48409677e+01], [ 6.88655876e+01], [ 4.27419924e+01], [ 2.40002615e+01], [ 2.34207249e+01], [ 6.16721244e+01], [ 2.54942028e+01], [ 1.90048098e+01], [ 3.48866829e+01], [ 9.40231340e+00], [ 2.95200113e+01], [ 1.45739659e+01], [ 9.12556314e+00], [ 5.28125840e+01], [ 4.50395380e+01], [ 1.74524347e+01], [ 3.84939353e+01], [ 2.70389191e+01], [ 6.55817097e+01], [ 7.03730638e+00], [ 5.27144771e+01], [ 3.82064593e+01], [ 2.11698011e+01], [ 3.02475569e+01], [ 2.71442299e+00], [ 1.99329326e+01], [-3.41333234e+00], [ 3.24459994e+01], [ 1.05829730e+01], [ 2.17752257e+01], [ 6.24652921e+01], [ 2.41329437e+01], [ 2.62012396e+01], [ 6.37444772e+01], [ 2.83429777e+00], [ 1.43792470e+01], [ 9.36985073e+00], [ 9.88116661e+00], [ 3.49494536e+00], [ 1.22608049e+02], [ 2.10835130e+01], [ 1.75322206e+01], [ 2.01830983e+01], [ 3.63931322e+01], [ 3.49351512e+01], [ 1.88303127e+01], [ 3.83445555e+01], [ 7.79166341e+01], [ 1.79532355e+00], [ 1.34458279e+01], [ 3.61311556e+01], [ 1.51504035e+01], [ 1.29418483e+01], [ 1.13125241e+02], [ 1.52246047e+01], [ 1.48240260e+01], [ 5.92673537e+01], [ 1.05836953e+01], [ 2.09930626e+01], [ 9.78936588e+00], [ 4.77118001e+00], [ 4.79278069e+01], [ 1.23994384e+01], [ 4.81464766e+01], [ 4.04663804e+01], [ 1.69405903e+01], [ 4.12665445e+01], [ 6.90278920e+01], [ 4.03462492e+01], [ 1.43137440e+01], [ 1.57707266e+01]])
保存预测文件
import csv with open('work/submit.csv', mode='w', newline='') as submit_file: csv_writer = csv.writer(submit_file) header = ['id', 'value'] print(header) csv_writer.writerow(header) for i in range(240): row = ['id_' + str(i), ans_y[i][0]] csv_writer.writerow(row) print(row)
['id', 'value'] ['id_0', 5.174960398984736] ['id_1', 18.30621425352788] ['id_2', 20.491218094180528] ['id_3', 11.523942869805332] ['id_4', 26.616056752306132] ['id_5', 20.531348081761223] ['id_6', 21.906551018797376] ['id_7', 31.736468747068834] ['id_8', 13.391674055111736] ['id_9', 64.45646650291954] ['id_10', 20.264568836159437] ['id_11', 15.35857607736122] ['id_12', 68.58947276926726] ['id_13', 48.428113747457196] ['id_14', 18.702333824193218] ['id_15', 10.188595737466716] ['id_16', 30.74036285982044] ['id_17', 71.13221776355108] ['id_18', -4.130517391262444] ['id_19', 18.235694016428695] ['id_20', 38.578922275007756] ['id_21', 71.31151972531332] ['id_22', 7.410348162634086] ['id_23', 18.717955330321395] ['id_24', 14.93725026008458] ['id_25', 36.719736694705325] ['id_26', 17.96169700566271] ['id_27', 75.78946287210539] ['id_28', 12.309310248614484] ['id_29', 56.2953517396496] ['id_30', 25.113160865661484] ['id_31', 4.610248674094053] ['id_32', 2.4837705545150315] ['id_33', 24.75942226132128] ['id_34', 30.480280465591196] ['id_35', 38.46393074642666] ['id_36', 44.20231060933005] ['id_37', 30.08683601986601] ['id_38', 40.47367501574008] ['id_39', 29.22647990231738] ['id_40', 5.606456054343949] ['id_41', 38.666016078789596] ['id_42', 34.61021343187721] ['id_43', 48.38969750738482] ['id_44', 14.757247666944172] ['id_45', 34.46682011087208] ['id_46', 27.48310687418436] ['id_47', 12.000879378154043] ['id_48', 21.378036151603794] ['id_49', 28.54440309166328] ['id_50', 20.16551381841159] ['id_51', 10.796678149746501] ['id_52', 22.171035755750125] ['id_53', 53.446263109352266] ['id_54', 12.21958112161002] ['id_55', 43.30096845517155] ['id_56', 32.1823351032854] ['id_57', 22.5672175145708] ['id_58', 56.73951416554704] ['id_59', 20.745052945295473] ['id_60', 15.028854557473265] ['id_61', 39.8553015903851] ['id_62', 12.975340680728284] ['id_63', 51.74165959283004] ['id_64', 18.783369632539877] ['id_65', 12.348752842777712] ['id_66', 15.633623653541925] ['id_67', -0.05887147068500154] ['id_68', 41.50801107307596] ['id_69', 31.548747530656026] ['id_70', 18.604251157547075] ['id_71', 37.4768197248807] ['id_72', 56.52039065762305] ['id_73', 6.58787719352195] ['id_74', 12.229339737435051] ['id_75', 5.203696404134638] ['id_76', 47.92737510380059] ['id_77', 13.020705685594661] ['id_78', 17.110301693903597] ['id_79', 20.603234531002048] ['id_80', 21.284481560784613] ['id_81', 38.69293529051181] ['id_82', 30.020716675725847] ['id_83', 88.76740666723548] ['id_84', 35.984700239668264] ['id_85', 26.756913553477187] ['id_86', 23.963516843564403] ['id_87', 32.747242828083074] ['id_88', 22.18904375531994] ['id_89', 20.992158853626545] ['id_90', 29.555994316645446] ['id_91', 40.99216886651781] ['id_92', 8.625117809911558] ['id_93', 32.3214718088779] ['id_94', 46.59804436536759] ['id_95', 22.88407082672354] ['id_96', 31.518129728251655] ['id_97', 11.19823347976612] ['id_98', 28.527436642529608] ['id_99', 0.2911506800896443] ['id_100', 17.96696107953969] ['id_101', 27.124163929470143] ['id_102', 11.398232780652847] ['id_103', 16.426426865673527] ['id_104', 23.42526104692219] ['id_105', 40.6160826705684] ['id_106', 25.8641250265604] ['id_107', 5.422736951672389] ['id_108', 10.794921122256104] ['id_109', 72.86213692992126] ['id_110', 48.022837059481375] ['id_111', 15.746808276902996] ['id_112', 24.67041061417795] ['id_113', 12.827793326536716] ['id_114', 10.158057570240526] ['id_115', 27.269223342020982] ['id_116', 29.208738577932458] ['id_117', 8.835339619930767] ['id_118', 20.05108813712978] ['id_119', 20.212333743764248] ['id_120', 79.9060092987056] ['id_121', 18.061614288263595] ['id_122', 30.542809341304345] ['id_123', 25.98079237772804] ['id_124', 5.212577268164767] ['id_125', 30.355697305856214] ['id_126', 7.768322888914637] ['id_127', 15.328268255393336] ['id_128', 22.66636571769797] ['id_129', 62.742054211090085] ['id_130', 18.950780367987996] ['id_131', 19.076355630838545] ['id_132', 61.37157409163711] ['id_133', 15.884562052629718] ['id_134', 13.409418077705558] ['id_135', 0.8487724836112842] ['id_136', 7.834996717304126] ['id_137', 57.01282901179679] ['id_138', 25.607996751813804] ['id_139', 4.9617047292420855] ['id_140', 36.414879039062775] ['id_141', 28.790006721975917] ['id_142', 49.19412096197634] ['id_143', 40.3068698557345] ['id_144', 13.316180593982658] ['id_145', 27.661011875229164] ['id_146', 17.158027524366766] ['id_147', 49.68726256929682] ['id_148', 23.03027229160478] ['id_149', 39.240936524842766] ['id_150', 13.19675388941254] ['id_151', 5.948893701039413] ['id_152', 25.82160897630425] ['id_153', 8.258634214291634] ['id_154', 19.146320517225597] ['id_155', 43.18248652651674] ['id_156', 6.717843578093033] ['id_157', 33.869615246810646] ['id_158', 15.3699378469818] ['id_159', 16.939044973551923] ['id_160', 37.88533679463485] ['id_161', 19.202484541054467] ['id_162', 9.059504715654725] ['id_163', 10.283399610648509] ['id_164', 48.672447125698284] ['id_165', 30.58771621323082] ['id_166', 2.4774098975321657] ['id_167', 12.811603937805932] ['id_168', 70.32478980976464] ['id_169', 14.840967694067068] ['id_170', 68.8655875667886] ['id_171', 42.74199244486634] ['id_172', 24.000261542920168] ['id_173', 23.420724860321446] ['id_174', 61.672124435682356] ['id_175', 25.494202845059192] ['id_176', 19.004809786869096] ['id_177', 34.88668288189683] ['id_178', 9.40231339837975] ['id_179', 29.520011314408027] ['id_180', 14.573965885700483] ['id_181', 9.125563143203598] ['id_182', 52.81258399813187] ['id_183', 45.03953799438962] ['id_184', 17.452434679183295] ['id_185', 38.49393527971433] ['id_186', 27.03891909264382] ['id_187', 65.58170967424583] ['id_188', 7.0373063807695795] ['id_189', 52.71447713411572] ['id_190', 38.20645933704977] ['id_191', 21.16980105955784] ['id_192', 30.247556879488393] ['id_193', 2.714422989716304] ['id_194', 19.93293258764082] ['id_195', -3.413332337603944] ['id_196', 32.44599940281316] ['id_197', 10.582973029979941] ['id_198', 21.77522570725845] ['id_199', 62.465292065677886] ['id_200', 24.13294368731649] ['id_201', 26.201239647400964] ['id_202', 63.74447723440287] ['id_203', 2.83429777412905] ['id_204', 14.37924698697884] ['id_205', 9.369850731753894] ['id_206', 9.881166613595411] ['id_207', 3.4949453589721426] ['id_208', 122.6080493792178] ['id_209', 21.083513014480573] ['id_210', 17.53222059945511] ['id_211', 20.183098344597003] ['id_212', 36.39313221228185] ['id_213', 34.93515120529068] ['id_214', 18.83031266145864] ['id_215', 38.34455552272332] ['id_216', 77.91663413807038] ['id_217', 1.7953235508882215] ['id_218', 13.445827939135775] ['id_219', 36.131155590412135] ['id_220', 15.150403498166307] ['id_221', 12.941848334417926] ['id_222', 113.12524093786391] ['id_223', 15.224604677934382] ['id_224', 14.824025968612034] ['id_225', 59.267353688540446] ['id_226', 10.583695290718481] ['id_227', 20.993062563532174] ['id_228', 9.789365880830381] ['id_229', 4.77118000870597] ['id_230', 47.92780690481291] ['id_231', 12.399438394751039] ['id_232', 48.14647656264414] ['id_233', 40.46638039656415] ['id_234', 16.94059027033294] ['id_235', 41.26654448941875] ['id_236', 69.02789203372899] ['id_237', 40.34624924412241] ['id_238', 14.313743982871129] ['id_239', 15.770726634219828]
传说中的飞桨社区最菜代码人,让我们一起努力!
记住:三岁出品必是精品 (不要脸系列)