Linear Regression
1. Univariate Linear Regression
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # display Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False    # display minus signs correctly
1.1 Reading the Data
data = pd.read_csv("data/regress_data1.csv")
data.head()
   |   人口 |    收益
0  | 6.1101 | 17.5920
1  | 5.5277 |  9.1302
2  | 8.5186 | 13.6620
3  | 7.0032 | 11.8540
4  | 5.8598 |  6.8233
# Visualize the relationship between population and profit
data.plot(kind="scatter", x="人口", y="收益")
plt.xlabel("人口", fontsize=10)
plt.ylabel("收益", fontsize=10)
plt.title("人口与收益之间的关系")
[Figure: scatter plot of 人口 vs. 收益, titled "人口与收益之间的关系"]
1.2 Preparing the Training Data
data.insert(0,"ones",1) data
    | ones |    人口 |     收益
0   |    1 |  6.1101 | 17.59200
1   |    1 |  5.5277 |  9.13020
2   |    1 |  8.5186 | 13.66200
3   |    1 |  7.0032 | 11.85400
4   |    1 |  5.8598 |  6.82330
... |  ... |     ... |      ...
92  |    1 |  5.8707 |  7.20290
93  |    1 |  5.3054 |  1.98690
94  |    1 |  8.2934 |  0.14454
95  |    1 | 13.3940 |  9.05510
96  |    1 |  5.4369 |  0.61705

97 rows × 3 columns
col_num = data.shape[1]
m = data.shape[0]
# Features of the training set
X = data.iloc[:, :col_num - 1]
# Labels of the training set
y = data.iloc[:, col_num - 1]
X = X.values
y = y.values
X.shape,y.shape
((97, 2), (97,))
y = y.reshape((m, 1))
y.shape
(97, 1)
# Initialize the weight vector
w = np.zeros((col_num - 1, 1))
w.shape
(2, 1)
1.3 Defining the Hypothesis Function (used to make predictions)
# Estimate yhat
def h(X, w):
    # X has shape (m, col_num-1); w has shape (col_num-1, 1)
    return X @ w
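In formula form (a restatement of the code above), the univariate hypothesis is a line with an intercept, and the added ones column turns it into a single matrix product:

$$h_w(x) = w_0 + w_1 x, \qquad \hat{y} = Xw$$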
1.4 Defining the Loss (Cost) Function
# Define the MSE (mean squared error) cost function
def cost(X, y, w):
    temp = h(X, w)
    return np.sum(np.square(temp - y)) / (2 * m)
# Equivalent vectorized implementation
def computeCost(X, y, w):
    inner = np.power((X @ w) - y, 2)  # (m, n) @ (n, 1) -> (m, 1)
    # return np.sum(inner) / (2 * len(X))
    return np.sum(inner) / (2 * m)
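Both implementations compute the same quantity; written out, the cost being minimized is the (halved) mean squared error:

$$J(w) = \frac{1}{2m} \sum_{i=1}^{m} \left( h_w(x^{(i)}) - y^{(i)} \right)^2 = \frac{1}{2m} \lVert Xw - y \rVert_2^2$$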
cost(X,y,w)
32.072733877455676
error = h(X, w) - y
error.shape
(97, 1)
x1 = np.array([1, 2]).reshape(2, 1)
x2 = np.array([3, 4]).reshape(2, 1)
np.multiply(x1, x2)
array([[3], [8]])
X[:,1].shape
(97,)
X.shape,w.shape,y.shape
((97, 2), (2, 1), (97, 1))
h(X,w)-y
array([[-17.592 ], [ -9.1302 ], [-13.662 ], [-11.854 ], [ -6.8233 ], [-11.886 ], [ -4.3483 ], [-12. ], [ -6.5987 ], [ -3.8166 ], [ -3.2522 ], [-15.505 ], [ -3.1551 ], [ -7.2258 ], [ -0.71618], [ -3.5129 ], [ -5.3048 ], [ -0.56077], [ -3.6518 ], [ -5.3893 ], [ -3.1386 ], [-21.767 ], [ -4.263 ], [ -5.1875 ], [ -3.0825 ], [-22.638 ], [-13.501 ], [ -7.0467 ], [-14.692 ], [-24.147 ], [ 1.22 ], [ -5.9966 ], [-12.134 ], [ -1.8495 ], [ -6.5426 ], [ -4.5623 ], [ -4.1164 ], [ -3.3928 ], [-10.117 ], [ -5.4974 ], [ -0.55657], [ -3.9115 ], [ -5.3854 ], [ -2.4406 ], [ -6.7318 ], [ -1.0463 ], [ -5.1337 ], [ -1.844 ], [ -8.0043 ], [ -1.0179 ], [ -6.7504 ], [ -1.8396 ], [ -4.2885 ], [ -4.9981 ], [ -1.4233 ], [ 1.4211 ], [ -2.4756 ], [ -4.6042 ], [ -3.9624 ], [ -5.4141 ], [ -5.1694 ], [ 0.74279], [-17.929 ], [-12.054 ], [-17.054 ], [ -4.8852 ], [ -5.7442 ], [ -7.7754 ], [ -1.0173 ], [-20.992 ], [ -6.6799 ], [ -4.0259 ], [ -1.2784 ], [ -3.3411 ], [ 2.6807 ], [ -0.29678], [ -3.8845 ], [ -5.7014 ], [ -6.7526 ], [ -2.0576 ], [ -0.47953], [ -0.20421], [ -0.67861], [ -7.5435 ], [ -5.3436 ], [ -4.2415 ], [ -6.7981 ], [ -0.92695], [ -0.152 ], [ -2.8214 ], [ -1.8451 ], [ -4.2959 ], [ -7.2029 ], [ -1.9869 ], [ -0.14454], [ -9.0551 ], [ -0.61705]])
np.multiply((h(X,w)-y).ravel(),X[:,1]).shape
(97,)
1.5 Optimizing the Parameters w with Gradient Descent
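Gradient descent repeatedly moves each weight against the gradient of $J(w)$; the update implemented below is:

$$w_j := w_j - \frac{\alpha}{m} \sum_{i=1}^{m} \left( h_w(x^{(i)}) - y^{(i)} \right) x_j^{(i)}$$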
# Hyperparameters: number of iterations iter_num and learning rate alpha; the update uses all samples (batch gradient descent)
def gradient_descent(X, y, w, iter_num, alpha):
    temp = np.zeros((col_num - 1, 1))
    cost_lst = []
    for i in range(iter_num):
        error = h(X, w) - y
        for j in range(col_num - 1):
            incre = np.multiply(error.ravel(), X[:, j].ravel())
            temp[j, 0] = w[j, 0] - (alpha / m) * np.sum(incre)
        w = temp
        cost_lst.append(cost(X, y, w))
    return w, cost_lst
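The inner loop over j can also be written as a single matrix operation. A minimal vectorized sketch, equivalent to the function above (it assumes h, cost, and m as defined in this notebook, and is not part of the original code):

def gradient_descent_vec(X, y, w, iter_num, alpha):
    cost_lst = []
    for i in range(iter_num):
        # X.T @ (Xw - y) computes all partial derivatives at once
        grad = X.T @ (h(X, w) - y) / m
        w = w - alpha * grad
        cost_lst.append(cost(X, y, w))
    return w, cost_lst

Calling it with the same starting point, iteration count, and learning rate should reproduce the result below up to floating-point error.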
iter_num = 200
alpha = 0.003
w = np.zeros((col_num - 1, 1))
w, cost_lst = gradient_descent(X, y, w, iter_num, alpha)
w
array([[-0.32791203], [ 0.83460252]])
1.6 Visualizing the Cost Curve
plt.plot(range(iter_num), cost_lst, "r-+")
plt.xlabel("迭代次数")
plt.ylabel("误差")
plt.show()
1.7 Visualizing the Regression Line / Regression Plane
x = np.linspace(data["人口"].min(), data["人口"].max(), 50)
y1 = w[0, 0] * 1 + w[1, 0] * x
plt.plot(x, y1, "r-+", label="预测线")
plt.scatter(data["人口"], data["收益"], label='训练数据')
plt.xlabel("人口", fontsize=10)
plt.ylabel("收益", fontsize=10)
plt.title("人口与收益之间的关系")
plt.show()
w
array([[-0.32791203], [ 0.83460252]])
Summary:
- Prepare the data
- Initialize w
- Define the hypothesis function
- Define the loss (cost) function
- Define the gradient descent algorithm
- Visualize and analyze the results
2. Univariate Linear Regression: Trying sklearn
X.shape,y.shape
((97, 2), (97, 1))
import sklearn
from sklearn import linear_model

reg = linear_model.LinearRegression()
reg.fit(X, y)
reg.coef_
array([[0. , 1.19303364]])
w
array([[-0.32791203], [ 0.83460252]])
reg.intercept_
array([-3.89578088])
reg.get_params()
{'copy_X': True, 'fit_intercept': True, 'n_jobs': None, 'normalize': 'deprecated', 'positive': False}
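Because X still contains the manually added ones column while fit_intercept=True, the model learns the bias through intercept_ and assigns the constant column a coefficient of 0 (a constant column has no variance left after centering). A sketch of the more usual call, fitting on the raw feature only (not part of the original notebook; it assumes linear_model, X, and y as defined above):

from sklearn import linear_model

reg2 = linear_model.LinearRegression()
reg2.fit(X[:, 1:], y)        # drop the ones column; the intercept is fitted separately
reg2.coef_, reg2.intercept_  # expected to match reg.coef_[0, 1] and reg.intercept_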
reg.predict(X)-y
array([[-14.19822601], [ -6.4312488 ], [ -7.39480448], [ -7.39472766], [ -3.72814233], [ -5.78069914], [ 0.67551586], [ -5.66181898], [ -2.75622606], [ -1.68207302], [ -0.33492365], [ -2.50265234], [ -0.21002596], [ -1.09007678], [ 2.117584 ], [ -0.99087569], [ -1.60644452], [ 1.66383102], [ 0.12314824], [ -0.84937859], [ 0.34942365], [ -1.47998891], [ -1.60890687], [ -1.53603074], [ -0.33916795], [ -3.93175849], [ -2.09254529], [ 2.12958876], [ -2.86836958], [ -1.55385488], [ 3.59050903], [ -2.03100498], [ -4.99636713], [ 1.28383475], [ -0.64226232], [ 1.00673223], [ 1.6465002 ], [ -0.60007636], [ 1.30099898], [ -1.81336092], [ 1.99826273], [ 0.40377318], [ 4.68685703], [ 0.55183747], [ -1.29245052], [ 3.52022606], [ -2.9805617 ], [ 1.18148451], [ 2.05841276], [ 1.69763436], [ -1.65046859], [ 0.59688379], [ 0.67268159], [ 0.17687322], [ 2.23616258], [ 5.11170076], [ 1.11395081], [ -1.77162904], [ 3.24920096], [ 1.96858198], [ 1.46381825], [ 3.02608828], [ 3.56178204], [ 1.83596469], [ 1.66894398], [ -0.16942543], [ 0.2563525 ], [ 0.5407115 ], [ 1.64788834], [ -0.62028352], [ 1.51690814], [ 0.82862438], [ 1.9914178 ], [ 1.38386093], [ 4.78217995], [ 3.61930412], [ 1.21352255], [ -3.58846693], [ 1.60884678], [ 0.14027707], [ 2.45981748], [ 2.08994488], [ 3.00817305], [ 0.21510688], [ -1.46569296], [ 2.02402528], [ 0.25840658], [ 2.33785705], [ 2.53824205], [ -0.68114646], [ 1.06859725], [ 0.91903985], [ -4.09473826], [ 0.44683982], [ 5.85398435], [ 3.02861175], [ 1.97357374]])
reg.score(X,y)
0.7020315537841397
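For regressors, score returns the coefficient of determination $R^2$:

$$R^2 = 1 - \frac{\sum_i \left( y_i - \hat{y}_i \right)^2}{\sum_i \left( y_i - \bar{y} \right)^2}$$

so roughly 70% of the variance in 收益 is explained by this linear fit.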
3. Multivariate Linear Regression
path = 'data/regress_data2.csv'
data2 = pd.read_csv(path)
data2.head()
   | 面积 | 房间数 |   价格
0  | 2104 |      3 | 399900
1  | 1600 |      3 | 329900
2  | 2400 |      3 | 369000
3  | 1416 |      2 | 232000
4  | 3000 |      4 | 539900
data2 = (data2 - data2.mean()) / data2.std()
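The two features are on very different scales (floor area vs. number of rooms), so every column is standardized to zero mean and unit variance before running gradient descent:

$$x' = \frac{x - \mu}{\sigma}$$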
data2.head()
   |      面积 |    房间数 |      价格
0  |  0.130010 | -0.223675 |  0.475747
1  | -0.504190 | -0.223675 | -0.084074
2  |  0.502476 | -0.223675 |  0.228626
3  | -0.735723 | -1.537767 | -0.867025
4  |  1.257476 |  1.090417 |  1.595389
Exercise 1: Prepare the training data
data2.insert(0, "ones", 1)
col_num2 = data2.shape[1]
m2 = data2.shape[0]
X2 = data2.iloc[:, :-1].values
y2 = data2.iloc[:, -1].values.reshape((data2.shape[0], 1))
w2 = np.zeros((X2.shape[1], 1))
X2.shape,y2.shape,w2.shape
((47, 3), (47, 1), (3, 1))
Exercise 2: Apply the gradient descent routine from above
# Define the MSE (mean squared error) cost function for the second dataset
def cost2(X, y, w):
    temp = h(X, w)
    return np.sum(np.square(temp - y)) / (2 * m2)

# Hyperparameters: number of iterations and learning rate alpha; the update uses all samples (batch gradient descent)
def gradient_descent(X, y, w, iter_num, alpha):
    temp = np.zeros((col_num2 - 1, 1))
    cost_lst = []
    for i in range(iter_num):
        error = h(X, w) - y
        for j in range(col_num2 - 1):
            incre = np.multiply(error.ravel(), X[:, j].ravel())
            temp[j, 0] = w[j, 0] - (alpha / m2) * np.sum(incre)
        w = temp
        cost_lst.append(cost2(X, y, w))
    return w, cost_lst
iter_num2 = 1000
alpha2 = 0.01
w2, cost_lst2 = gradient_descent(X2, y2, w2, iter_num2, alpha2)
w2
array([[-1.03191687e-16], [ 8.78503652e-01], [-4.69166570e-02]])
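As a quick sanity check (a sketch, not in the original notebook; it assumes X2 and y2 as prepared above), the closed-form least-squares solution can be compared against the gradient-descent estimate:

import numpy as np

# Solve the normal equations X2^T X2 w = X2^T y2 directly
w2_closed = np.linalg.solve(X2.T @ X2, X2.T @ y2)
w2_closed  # should be close to w2 after enough iterations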
cost_lst2
[0.4805491041076719, 0.47198587701203876, 0.46366461618706284, 0.4555781400525299, 0.44771948335326117, 0.4400818906150644, 0.43265880979889004, 0.42544388614718714, 0.41843095621663473, 0.4116140420916035, 0.4049873457728717, 0.39854524373628347, 0.3922822816562035, 0.38619316928877434, 0.3802727755101314, 0.3745161235048873, 0.36891838610032585, 0.36347488124189714, 0.3581810676057273, 0.353032540343996, 0.34802502695915444, 0.3431543833030803, 0.33841658969738386, 0.3338077471711977, 0.3293240738128865, 0.32496190123222957, 0.32071767112972566, 0.3165879319697778, 0.3125693357546089, 0.3086586348958572, 0.3048526791808924, 0.301148412830983, 0.29754287164853055, 0.29403318025067643, 0.290616549386659, 0.2872902733363892, 0.2840517273877804, 0.2808983653904495, 0.2778277173834725, 0.2748373872949541, 0.27192505071123485, 0.2690884527136251, 0.26632540578062175, 0.26363378775362334, 0.2610115398642199, 0.2584566648211922, 0.25596722495541263, 0.2535413404208921, 0.2511771874502738, 0.24887299666312288, 0.2466270514254147, 0.2444376862586688, 0.2423032852972262, 0.24022228079221122, 0.23819315166076374, 0.23621442207916982, 0.23428466011856308, 0.23240247642190492, 0.2305665229209955, 0.22877549159230046, 0.22702811325042083, 0.2253231563780625, 0.2236594259914031, 0.22203576253978147, 0.22045104083867165, 0.21890416903493287, 0.2173940876033578, 0.2159197683735719, 0.21448021358636368, 0.21307445497855435, 0.21170155289554488, 0.21036059543069827, 0.20905069759074849, 0.2077710004864442, 0.20652067054766643, 0.20529889876227617, 0.20410489993797507, 0.20293791198648126, 0.20179719522934592, 0.20068203172475318, 0.1995917246146703, 0.19852559749172968, 0.1974829937852473, 0.19646327616579626, 0.19546582596777473, 0.1944900426294234, 0.1935353431497636, 0.19260116156194368, 0.1916869484224977, 0.1907921703160338, 0.189916309374886, 0.18905886281327441, 0.18821934247553765, 0.1873972743980089, 0.1865921983841233, 0.185803667592357, 0.1850312481366081, 0.1842745186986435, 0.18353307015224665, 0.18280650519871142, 0.18209443801333897, 0.18139649390260434, 0.1807123089716706, 0.18004152980193594, 0.17938381313831125, 0.17873882558593354, 0.17810624331602853, 0.17748575178064738, 0.1768770454360068, 0.17627982747417442, 0.17569380956284517, 0.17511871159296452, 0.17455426143396124, 0.17400019469635858, 0.17345625450154267, 0.17292219125846853, 0.17239776244709656, 0.1718827324083545, 0.1713768721404272, 0.170879959101184, 0.17039177701655678, 0.16991211569468956, 0.16944077084568412, 0.16897754390677372, 0.16852224187275908, 0.16807467713154944, 0.16763466730465196, 0.16720203509246181, 0.16677660812420697, 0.16635821881240656, 0.1659467042117072, 0.165541905881964, 0.16514366975543857, 0.16475184600798926, 0.16436628893413294, 0.16398685682586134, 0.1636134118550983, 0.16324581995968843, 0.1628839507328093, 0.16252767731570503, 0.16217687629363958, 0.16183142759497396, 0.16149121439327116, 0.1611561230123392, 0.1608260428341213, 0.1605008662093497, 0.16018048837087778, 0.15986480734960967, 0.15955372389294997, 0.1592471413856969, 0.1589449657733043, 0.1586471054874422, 0.1583534713737859, 0.1580639766219659, 0.15777853669761477, 0.15749706927644566, 0.15721949418030304, 0.1569457333151249, 0.15667571061075897, 0.15640935196257785, 0.15614658517483707, 0.15588733990572493, 0.1556315476140529, 0.15537914150753623, 0.15513005649261752, 0.15488422912578662, 0.15464159756635176, 0.15440210153061742, 0.15416568224742783, 0.15393228241503384, 0.15370184615924348, 0.15347431899281805, 
0.15324964777607533, 0.15302778067866443, 0.15280866714247615, 0.1525922578456555, 0.15237850466768188, 0.15216736065548658, 0.15195877999057447, 0.1517527179571211, 0.1515491309110149, 0.15134797624981636, 0.15114921238360698, 0.1509527987067004, 0.15075869557019017, 0.15056686425530957, 0.15037726694757766, 0.15018986671170936, 0.15000462746726564, 0.1498215139650219, 0.1496404917640332, 0.14946152720937447, 0.14928458741053677, 0.14910964022045875, 0.14893665421517455, 0.14876559867406025, 0.14859644356065968, 0.1484291595040733, 0.1482637177808928, 0.14810009029766483, 0.1479382495738683, 0.14777816872538996, 0.1476198214484827, 0.1474631820041929, 0.1473082252032421, 0.14715492639134978, 0.14700326143498368, 0.14685320670752527, 0.14670473907583773, 0.14655783588722407, 0.14641247495676463, 0.1462686345550211, 0.14612629339609778, 0.14598543062604827, 0.14584602581161768, 0.14570805892930994, 0.14557151035477142, 0.14543636085247996, 0.1453025915657318, 0.14517018400691611, 0.1450391200480697, 0.14490938191170247, 0.14478095216188633, 0.14465381369559951, 0.14452794973431848, 0.1444033438158502, 0.14427997978639776, 0.14415784179285188, 0.14403691427530232, 0.14391718195976197, 0.14379862985109793, 0.14368124322616269, 0.14356500762711993, 0.14344990885495965, 0.14333593296319558, 0.1432230662517408, 0.1431112952609561, 0.14300060676586493, 0.14289098777053136, 0.14278242550259543, 0.14267490740796127, 0.14256842114563387, 0.14246295458269945, 0.142358495789446, 0.14225503303461925, 0.14215255478081018, 0.1420510496799703, 0.14195050656905087, 0.14185091446576267, 0.1417522625644519, 0.14165454023209023, 0.14155773700437435, 0.141461842581932, 0.14136684682663245, 0.1412727397579967, 0.14117951154970584, 0.14108715252620416, 0.14099565315939372, 0.1409050040654189, 0.14081519600153714, 0.14072621986307407, 0.1406380666804604, 0.14055072761634763, 0.1404641939628014, 0.14037845713856925, 0.1402935086864208, 0.14020934027055865, 0.14012594367409767, 0.14004331079661067, 0.1399614336517384, 0.1398803043648625, 0.13979991517083906, 0.13972025841179114, 0.13964132653495887, 0.1395631120906052, 0.13948560772997556, 0.1394088062033102, 0.13933270035790743, 0.1392572831362367, 0.13918254757409942, 0.13910848679883667, 0.13903509402758246, 0.138962362565561, 0.13889028580442658, 0.13881885722064558, 0.1387480703739184, 0.138677918905641, 0.1386083965374045, 0.13853949706953195, 0.1384712143796508, 0.13840354242130104, 0.13833647522257653, 0.13827000688480018, 0.1382041315812306, 0.1381388435558005, 0.13807413712188513, 0.13801000666110028, 0.1379464466221289, 0.13788345151957584, 0.13782101593284923, 0.1377591345050687, 0.13769780194199868, 0.13763701301100714, 0.13757676254004808, 0.13751704541666765, 0.13745785658703336, 0.1373991910549853, 0.1373410438811091, 0.13728341018182993, 0.1372262851285268, 0.13716966394666721, 0.13711354191496053, 0.13705791436453063, 0.13700277667810684, 0.1369481242892326, 0.13689395268149127, 0.13684025738774938, 0.13678703398941564, 0.13673427811571623, 0.1366819854429858, 0.13663015169397344, 0.13657877263716303, 0.1365278440861087, 0.1364773618987835, 0.13642732197694216, 0.13637772026549663, 0.13632855275190497, 0.13627981546557258, 0.13623150447726529, 0.13618361589853506, 0.1361361458811566, 0.13608909061657562, 0.1360424463353681, 0.13599620930670997, 0.13595037583785746, 0.13590494227363753, 0.13585990499594844, 0.1358152604232695, 0.1357710050101804, 0.13572713524689023, 0.1356836476587745, 0.1356405388059217, 0.13559780528268783, 0.1355554437172595, 
0.13551345077122479, 0.13547182313915254, 0.13543055754817865, 0.1353896507576004, 0.13534909955847768, 0.13530890077324167, 0.13526905125531047, 0.1352295478887108, 0.13519038758770774, 0.13515156729643937, 0.13511308398855892, 0.13507493466688233, 0.13503711636304225, 0.1349996261371476, 0.13496246107744925, 0.1349256183000107, 0.13488909494838502, 0.13485288819329633, 0.1348169952323271, 0.13478141328961019, 0.1347461396155259, 0.1347111714864043, 0.1346765062042316, 0.13464214109636177, 0.13460807351523252, 0.13457430083808555, 0.13454082046669139, 0.13450762982707826, 0.13447472636926538, 0.1344421075670001, 0.1344097709174989, 0.13437771394119274, 0.1343459341814757, 0.1343144292044577, 0.13428319659872076, 0.1342522339750784, 0.1342215389663395, 0.1341911092270747, 0.13416094243338644, 0.1341310362826823, 0.13410138849345168, 0.13407199680504514, 0.13404285897745766, 0.1340139727911139, 0.13398533604665733, 0.13395694656474152, 0.13392880218582504, 0.13390090076996808, 0.1338732401966332, 0.1338458183644872, 0.13381863319120707, 0.1337916826132873, 0.1337649645858507, 0.1337384770824609, 0.13371221809493786, 0.13368618563317541, 0.1336603777249612, 0.13363479241579912, 0.1336094277687338, 0.13358428186417728, 0.13355935279973807, 0.13353463869005203, 0.13351013766661574, 0.1334858478776214, 0.1334617674877945, 0.13343789467823247, 0.13341422764624633, 0.13339076460520344, 0.1333675037843725, 0.13334444342877066, 0.13332158179901155, 0.13329891717115624, 0.13327644783656503, 0.13325417210175164, 0.13323208828823854, 0.13321019473241433, 0.13318848978539244, 0.13316697181287218, 0.13314563919500016, 0.1331244903262345, 0.13310352361520966, 0.13308273748460353, 0.13306213037100528, 0.13304170072478536, 0.13302144700996643, 0.13300136770409596, 0.13298146129812038, 0.13296172629625996, 0.13294216121588615, 0.132922764587399, 0.1329035349541069, 0.13288447087210706, 0.13286557091016762, 0.13284683364961072, 0.1328282576841967, 0.13280984162001042, 0.1327915840753473, 0.13277348368060227, 0.13275553907815818, 0.13273774892227672, 0.13272011187898983, 0.13270262662599233, 0.1326852918525356, 0.13266810625932268, 0.13265106855840397, 0.1326341774730744, 0.13261743173777144, 0.13260083009797413, 0.13258437131010334, 0.13256805414142278, 0.13255187736994117, 0.13253583978431552, 0.1325199401837546, 0.1325041773779249, 0.13248855018685587, 0.1324730574408471, 0.13245769798037618, 0.13244247065600748, 0.13242737432830168, 0.13241240786772626, 0.13239757015456718, 0.13238286007884084, 0.1323682765402074, 0.13235381844788482, 0.1323394847205633, 0.13232527428632143, 0.13231118608254225, 0.1322972190558306, 0.13228337216193134, 0.13226964436564798, 0.13225603464076227, 0.1322425419699548, 0.13222916534472598, 0.13221590376531783, 0.13220275624063701, 0.13218972178817762, 0.13217679943394567, 0.13216398821238384, 0.13215128716629693, 0.13213869534677797, 0.13212621181313536, 0.1321138356328202, 0.13210156588135483, 0.13208940164226146, 0.132077342006992, 0.1320653860748583, 0.1320535329529629, 0.1320417817561306, 0.1320301316068407, 0.13201858163516003, 0.1320071309786758, 0.13199577878243007, 0.13198452419885442, 0.1319733663877049, 0.131962304515998, 0.13195133775794732, 0.1319404652949001, 0.1319296863152752, 0.1319190000145011, 0.1319084055949546, 0.1318979022659001, 0.13188748924342955, 0.1318771657504026, 0.1318669310163877, 0.13185678427760358, 0.13184672477686088, 0.13183675176350523, 0.13182686449335979, 0.13181706222866899, 0.1318073442380425, 0.13179770979639993, 0.1317881581849157, 
0.1317786886909647, 0.13176930060806835, 0.13175999323584098, 0.1317507658799369, 0.13174161785199792, 0.13173254846960125, 0.13172355705620795, 0.13171464294111176, 0.13170580545938826, 0.13169704395184512, 0.1316883577649718, 0.13167974625089038, 0.13167120876730698, 0.13166274467746283, 0.13165435335008663, 0.1316460341593466, 0.13163778648480368, 0.13162960971136442, 0.13162150322923485, 0.13161346643387453, 0.13160549872595084, 0.13159759951129413, 0.13158976820085277, 0.13158200421064908, 0.13157430696173483, 0.13156667588014856, 0.13155911039687143, 0.1315516099477853, 0.13154417397362975, 0.13153680191996023, 0.13152949323710658, 0.1315222473801313, 0.13151506380878897, 0.13150794198748567, 0.13150088138523836, 0.1314938814756356, 0.13148694173679754, 0.131480061651337, 0.13147324070632052, 0.1314664783932301, 0.1314597742079246, 0.1314531276506024, 0.13144653822576377, 0.13144000544217335, 0.13143352881282383, 0.13142710785489936, 0.13142074208973897, 0.131414431042801, 0.13140817424362766, 0.13140197122580938, 0.13139582152695017, 0.131389724688633, 0.13138368025638517, 0.13137768777964465, 0.13137174681172598, 0.13136585690978717, 0.1313600176347961, 0.13135422855149823, 0.1313484892283835, 0.13134279923765432, 0.1313371581551934, 0.13133156556053213, 0.13132602103681912, 0.13132052417078882, 0.13131507455273095, 0.13130967177645936, 0.13130431543928223, 0.13129900514197135, 0.13129374048873282, 0.13128852108717703, 0.1312833465482897, 0.13127821648640217, 0.13127313051916345, 0.13126808826751082, 0.13126308935564196, 0.1312581334109867, 0.13125322006417925, 0.13124834894903042, 0.13124351970250062, 0.13123873196467223, 0.13123398537872316, 0.1312292795908998, 0.13122461425049098, 0.13121998900980153, 0.13121540352412622, 0.13121085745172428, 0.13120635045379372, 0.13120188219444595, 0.13119745234068078, 0.1311930605623617, 0.1311887065321909, 0.1311843899256851, 0.1311801104211511, 0.13117586769966202, 0.1311716614450333, 0.13116749134379915, 0.13116335708518903, 0.13115925836110465, 0.13115519486609695, 0.13115116629734297, 0.13114717235462375, 0.13114321274030158, 0.13113928715929773, 0.1311353953190708, 0.1311315369295945, 0.13112771170333615, 0.1311239193552353, 0.13112015960268236, 0.13111643216549748, 0.1311127367659097, 0.1311090731285363, 0.13110544098036211, 0.13110184005071918, 0.13109827007126654, 0.13109473077597045, 0.1310912219010841, 0.13108774318512836, 0.13108429436887195, 0.13108087519531228, 0.13107748540965625, 0.13107412475930125, 0.1310707929938162, 0.131067489864923, 0.1310642151264781, 0.13106096853445376, 0.1310577498469203, 0.1310545588240278, 0.13105139522798825, 0.1310482588230578, 0.13104514937551928, 0.13104206665366452, 0.13103901042777755, 0.13103598047011686, 0.13103297655489882, 0.13102999845828087, 0.1310270459583445, 0.13102411883507903, 0.13102121687036491, 0.13101833984795783, 0.1310154875534722, 0.13101265977436552, 0.13100985629992212, 0.13100707692123786, 0.1310043214312044, 0.13100158962449351, 0.13099888129754225, 0.1309961962485374, 0.13099353427740043, 0.13099089518577295, 0.13098827877700148, 0.13098568485612289, 0.1309831132298502, 0.13098056370655764, 0.13097803609626688, 0.13097553021063246, 0.13097304586292796, 0.130970582868032, 0.13096814104241458, 0.13096572020412311, 0.13096332017276915, 0.13096094076951487, 0.13095858181705952, 0.13095624313962653, 0.13095392456295019, 0.13095162591426274, 0.13094934702228142, 0.13094708771719588, 0.13094484783065521, 0.1309426271957558, 0.13094042564702846, 0.13093824302042653, 
0.1309360791533131, 0.1309339338844496, 0.13093180705398308, 0.1309296985034348, 0.13092760807568815, 0.13092553561497697, 0.1309234809668741, 0.13092144397827965, 0.13091942449740973, 0.1309174223737851, 0.1309154374582199, 0.13091346960281078, 0.13091151866092543, 0.13090958448719198, 0.13090766693748815, 0.13090576586893035, 0.13090388113986332, 0.1309020126098491, 0.130900160139657, 0.1308983235912531, 0.13089650282778978, 0.13089469771359577, 0.13089290811416598, 0.13089113389615128, 0.13088937492734887, 0.13088763107669207, 0.130885902214241, 0.13088418821117243, 0.13088248893977056, 0.13088080427341722, 0.13087913408658264, 0.1308774782548159, 0.1308758366547358, 0.13087420916402173, 0.1308725956614043, 0.13087099602665653, 0.13086941014058484, 0.13086783788502007, 0.13086627914280877, 0.13086473379780447, 0.1308632017348589, 0.13086168283981364, 0.13086017699949143, 0.13085868410168772, 0.13085720403516238, 0.13085573668963163, 0.13085428195575916, 0.13085283972514883, 0.13085140989033592, 0.1308499923447795, 0.13084858698285434, 0.13084719369984307, 0.13084581239192838, 0.13084444295618522, 0.13084308529057329, 0.1308417392939292, 0.13084040486595921, 0.13083908190723154, 0.13083777031916902, 0.13083647000404175, 0.13083518086495988, 0.1308339028058663, 0.1308326357315295, 0.13083137954753649, 0.1308301341602858, 0.13082889947698037, 0.1308276754056209, 0.13082646185499866, 0.13082525873468898, 0.13082406595504426, 0.13082288342718762, 0.1308217110630058, 0.13082054877514324, 0.13081939647699484, 0.13081825408270004, 0.13081712150713629, 0.13081599866591254, 0.13081488547536319, 0.13081378185254178, 0.1308126877152146, 0.13081160298185487, 0.1308105275716365, 0.13080946140442812, 0.13080840440078706, 0.1308073564819534, 0.1308063175698443, 0.13080528758704793, 0.13080426645681778, 0.13080325410306715, 0.13080225045036314, 0.13080125542392118, 0.13080026894959965, 0.1307992909538939, 0.13079832136393135, 0.13079736010746548, 0.13079640711287094, 0.1307954623091379, 0.13079452562586683, 0.13079359699326334, 0.13079267634213287, 0.13079176360387565, 0.1307908587104814, 0.13078996159452447, 0.1307890721891588, 0.1307881904281127, 0.13078731624568427, 0.13078644957673607, 0.13078559035669074, 0.13078473852152586, 0.1307838940077693, 0.1307830567524944, 0.13078222669331543, 0.13078140376838285, 0.13078058791637864, 0.13077977907651198, 0.13077897718851425, 0.13077818219263512, 0.1307773940296377, 0.13077661264079418, 0.13077583796788161, 0.13077506995317736, 0.130774308539455, 0.13077355366997984, 0.13077280528850505, 0.13077206333926703, 0.13077132776698147, 0.1307705985168394, 0.13076987553450267, 0.1307691587661004, 0.13076844815822458, 0.13076774365792623, 0.13076704521271165, 0.13076635277053805, 0.1307656662798101, 0.13076498568937595, 0.13076431094852323, 0.13076364200697563, 0.1307629788148889, 0.13076232132284715, 0.1307616694818592, 0.13076102324335503, 0.13076038255918201, 0.1307597473816014, 0.13075911766328474, 0.1307584933573104, 0.13075787441716, 0.13075726079671504, 0.13075665245025325, 0.13075604933244553, 0.13075545139835226, 0.13075485860342015, 0.13075427090347866, 0.13075368825473718, 0.13075311061378125, 0.13075253793756964, 0.13075197018343104, 0.13075140730906076, 0.13075084927251807, 0.1307502960322223, 0.13074974754695046, 0.13074920377583368, 0.13074866467835453, 0.13074813021434364, 0.130747600343977, 0.13074707502777286, 0.13074655422658876, 0.1307460379016188, 0.1307455260143904, 0.13074501852676193, 0.13074451540091928, 0.1307440165993736, 
0.13074352208495807, 0.13074303182082542, 0.130742545770445, 0.1307420638976003, 0.1307415861663858, 0.13074111254120485, 0.13074064298676666, 0.1307401774680837, 0.13073971595046924, 0.1307392583995346, 0.13073880478118677, 0.13073835506162565, 0.13073790920734168, 0.13073746718511345, 0.13073702896200484, 0.13073659450536299, 0.13073616378281558, 0.13073573676226868, 0.130735313411904, 0.13073489370017688, 0.13073447759581372, 0.1307340650678097, 0.13073365608542659, 0.13073325061819008, 0.1307328486358882, 0.13073245010856818, 0.13073205500653512, 0.13073166330034908, 0.1307312749608232, 0.13073088995902146, 0.1307305082662567, 0.13073012985408813, 0.1307297546943194, 0.13072938275899676, 0.1307290140204064, 0.13072864845107288, 0.13072828602375694, 0.13072792671145325, 0.1307275704873888, 0.1307272173250206, 0.1307268671980337, 0.13072652008033953, 0.13072617594607358, 0.13072583476959368, 0.13072549652547816, 0.1307251611885236, 0.1307248287337435, 0.13072449913636588, 0.13072417237183181, 0.13072384841579335, 0.13072352724411188, 0.1307232088328562, 0.130722893158301, 0.13072258019692454, 0.13072226992540745, 0.1307219623206308, 0.13072165735967436, 0.13072135501981477, 0.13072105527852415, 0.1307207581134681, 0.13072046350250424, 0.13072017142368056, 0.13071988185523364, 0.13071959477558712, 0.1307193101633501, 0.13071902799731558, 0.13071874825645874, 0.1307184709199356, 0.13071819596708112, 0.1307179233774081, 0.13071765313060527, 0.130717385206536, 0.1307171195852367, 0.13071685624691548, 0.13071659517195028, 0.13071633634088803, 0.13071607973444263, 0.13071582533349385, 0.1307155731190857, 0.13071532307242514, 0.1307150751748808, 0.13071482940798118, 0.13071458575341377, 0.1307143441930234, 0.13071410470881087, 0.13071386728293166, 0.13071363189769475, 0.13071339853556113, 0.13071316717914244, 0.13071293781119986, 0.13071271041464275, 0.1307124849725273, 0.1307122614680553, 0.13071203988457306, 0.13071182020556996, 0.13071160241467716, 0.13071138649566671, 0.13071117243244995, 0.1307109602090767, 0.13071074980973368, 0.13071054121874362, 0.13071033442056404, 0.13071012939978588, 0.1307099261411327, 0.13070972462945923, 0.13070952484975046, 0.1307093267871204, 0.13070913042681098, 0.130708935754191, 0.13070874275475497, 0.1307085514141222, 0.13070836171803538, 0.13070817365235993, 0.1307079872030827, 0.13070780235631096, 0.1307076190982714, 0.1307074374153091, 0.13070725729388652, 0.13070707872058232, 0.1307069016820908, 0.13070672616522036, 0.13070655215689286, 0.13070637964414267, 0.13070620861411547, 0.1307060390540674, 0.13070587095136435, 0.1307057042934805, 0.1307055390679979, 0.13070537526260517, 0.13070521286509698, 0.13070505186337264, 0.13070489224543566, 0.1307047339993925, 0.13070457711345204, 0.13070442157592427, 0.13070426737521984, 0.1307041144998489, 0.1307039629384204, 0.13070381267964126, 0.13070366371231526, 0.13070351602534264, 0.13070336960771892]
Exercise 3: Plot the cost curve
plt.plot(range(iter_num2), cost_lst2, "r-+")
plt.xlabel("迭代次数")
plt.ylabel("误差")
plt.show()
4. Solving for the Parameters with Least Squares (Normal Equation)
Least squares solves for the optimal parameters $w^*$ in closed form:

$$w^* = (X^T X)^{-1} X^T y$$
Comparison of gradient descent and least squares:
Gradient descent: requires choosing a learning rate $\alpha$ and takes many iterations, but it still works well when the number of features $n$ is large, and it applies to many kinds of models.
Least squares: no learning rate to choose and the answer comes from a single computation, but it requires computing $(X^T X)^{-1}$. When the number of features $n$ is large the cost is high, since the matrix inversion takes roughly $O(n^3)$ time; it is usually acceptable for $n$ up to about 10000. It only applies to linear models and is not suitable for models such as logistic regression.
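For completeness, the closed form follows from setting the gradient of the cost to zero:

$$\nabla_w J(w) = \frac{1}{m} X^T (Xw - y) = 0 \;\Longrightarrow\; X^T X w = X^T y \;\Longrightarrow\; w^* = (X^T X)^{-1} X^T y$$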
def lsm(X, y):
    w = np.linalg.inv(X.T @ X) @ X.T @ y
    return w
def lsm_v(X, y):
    w = np.linalg.inv(np.dot(X.T, X))
    w = np.dot(w, X.T)
    w = np.dot(w, y)
    return w
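Inverting $X^T X$ explicitly can be numerically unstable when features are nearly collinear. A sketch of two more robust NumPy alternatives (not used in the original notebook):

import numpy as np

def lsm_lstsq(X, y):
    # Solve the least-squares problem directly without forming (X^T X)^{-1}
    w, residuals, rank, sv = np.linalg.lstsq(X, y, rcond=None)
    return w

def lsm_pinv(X, y):
    # Moore-Penrose pseudo-inverse, also avoids an explicit matrix inverse
    return np.linalg.pinv(X) @ y

Both should agree with lsm(X, y) up to floating-point error.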
lsm(X,y)
array([[-3.89578088], [ 1.19303364]])
lsm_v(X,y)
array([[-3.89578088], [ 1.19303364]])
5. How About Some Regularization?
5.1 Ordinary Linear Regression
from sklearn import linear_model

reg = linear_model.LinearRegression()
reg.fit(X, y)
LinearRegression()
# Back to the univariate regression problem
x = X
y_1 = reg.predict(x)
plt.plot(x, y_1, "r-+", label="预测线")
plt.scatter(data["人口"], data["收益"], label='训练数据')
plt.xlim(4.7, 10)
plt.xlabel("人口", fontsize=10)
plt.ylabel("收益", fontsize=10)
plt.title("人口与收益之间的关系")
plt.show()
reg.coef_,reg.intercept_,reg.score(X,y)
(array([[0. , 1.19303364]]), array([-3.89578088]), 0.7020315537841397)
5.2 Ridge Regression
from sklearn import linear_model

reg_rigde = linear_model.Ridge()
reg_rigde.fit(X, y)
Ridge()
# Back to the univariate regression problem, this time with Ridge
x = X
y_1 = reg_rigde.predict(x)
plt.plot(x, y_1, "r-+", label="预测线")
plt.scatter(data["人口"], data["收益"], label='训练数据')
plt.xlim(4.7, 10)
plt.xlabel("人口", fontsize=10)
plt.ylabel("收益", fontsize=10)
plt.title("人口与收益之间的关系")
plt.show()
reg_rigde.coef_,reg_rigde.intercept_,reg_rigde.score(X,y)
(array([[0. , 1.1922044]]), array([-3.88901439]), 0.7020312146131912)
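Ridge adds an L2 penalty to the least-squares objective; sklearn's Ridge minimizes

$$\lVert Xw - y \rVert_2^2 + \alpha \lVert w \rVert_2^2$$

with the intercept handled separately. With the default alpha=1.0 the shrinkage here is mild, which is why the slope barely changes (1.1922 vs. 1.1930).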
5.3 Lasso Regression
from sklearn import linear_model

reg_lasso = linear_model.Lasso()
reg_lasso.fit(X, y)
Lasso()
# Back to the univariate regression problem, this time with Lasso
x = X
y_1 = reg_lasso.predict(x)
plt.plot(x, y_1, "r-+", label="预测线")
plt.scatter(data["人口"], data["收益"], label='训练数据')
plt.xlim(4.7, 10)
plt.xlabel("人口", fontsize=10)
plt.ylabel("收益", fontsize=10)
plt.title("人口与收益之间的关系")
plt.show()
reg_lasso.coef_,reg_lasso.intercept_,reg_lasso.score(X,y)
(array([0. , 1.12556458]), array([-3.34524677]), 0.6997863246152711)
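Lasso uses an L1 penalty instead, which shrinks coefficients and can drive some of them exactly to zero; sklearn's Lasso minimizes

$$\frac{1}{2 n_{\text{samples}}} \lVert Xw - y \rVert_2^2 + \alpha \lVert w \rVert_1$$

The stronger shrinkage with the default alpha=1.0 explains the smaller slope (1.1256 vs. 1.1930) and the slightly lower score.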
Exercise 4: Implement L2 regularization for the univariate case by hand
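One common formulation (written to match the update used in the code below, which, as a simplification, also penalizes the bias term $w_0$) adds an L2 penalty to the cost:

$$J_{\text{reg}}(w) = \frac{1}{2m} \left[ \sum_{i=1}^{m} \left( h_w(x^{(i)}) - y^{(i)} \right)^2 + 2\lambda \sum_{j} w_j^2 \right], \qquad w_j := w_j - \frac{\alpha}{m} \left[ \sum_{i=1}^{m} \left( h_w(x^{(i)}) - y^{(i)} \right) x_j^{(i)} + 2\lambda w_j \right]$$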
# Hyperparameters: number of iterations iter_num, learning rate alpha, and regularization strength lambd; the update uses all samples
def gradient_descent_l2(X, y, w, iter_num, alpha, lambd):
    temp = np.zeros((col_num - 1, 1))
    cost_lst = []
    for i in range(iter_num):
        error = h(X, w) - y
        for j in range(col_num - 1):
            incre = np.multiply(error.ravel(), X[:, j].ravel())
            temp[j, 0] = w[j, 0] - (alpha / m) * (np.sum(incre) + 2 * lambd * w[j, 0])
        w = temp
        cost_lst.append(cost(X, y, w))  # note: tracks the unregularized MSE cost
    return w, cost_lst
iter_num = 200
alpha = 0.001
lambd = 2
w = np.zeros((col_num - 1, 1))
w, cost_lst = gradient_descent_l2(X, y, w, iter_num, alpha, lambd)
plt.plot(range(iter_num), cost_lst, "r-+")
plt.xlabel("迭代次数")
plt.ylabel("误差")
plt.show()