【Pytorch(二)】Numpy 搭建全连接神经网络(3)

简介: 【Pytorch(二)】Numpy 搭建全连接神经网络(3)

11. 数据集准备

现在是时候尝试应用我们的模型来解决一个简单的分类问题。为了检测模型是否能够顺利训练,下面我们将生成一个含有两个类的点集(如下图所示,两个类别的点分别用不同颜色表示),然后尝试训练模型来对这些点进行分类(二元分类问题)。

image.png



# number of samples in the data set
N_SAMPLES = 1000
# ratio between training and test sets
TEST_SIZE = 0.1

# we will use sklearn.make_moons() to generate the dataset:
#  - n_samples: 生成样本数量
#  - noise: 高斯噪声
#  - random_state: 生成随机种子,给定一个int型数据,能够保证每次生成数据相同
X, y = make_moons(n_samples = N_SAMPLES, noise=0.2, random_state=100)
# split the dataset into training set (90%) & test set (10%)
#  - test_size: 如果是浮点数,则应该在0.0和1.0之间,表示要测试集占总数据集的比例;如果是int类型,表示测试集的绝对数量。
#  - random_state: 随机数生成器使用的种子
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, random_state=42)

下面我们来观察一下刚刚生成好的数据集。


shape of X:  (1000, 2)
X =  [[ 2.24907069e-05  1.07275825e+00]
 [-5.59037377e-02  4.25241282e-01]
 [ 2.40944879e-02  4.08065802e-01]
 ...
 [ 1.75841594e+00 -5.77404262e-01]
 [ 1.26710180e+00 -4.42980152e-01]
 [-1.75927072e-01  5.83509936e-01]]
shape of X:  (900, 2)
# check the information in Y
print("shape of Y: ", y.shape)  # (1000,)
print("number of zeros in Y: ", np.sum(y==0))  # 500
print("number of ones in Y: ", np.sum(y==1))  # 500
print("Y = ", y)
shape of Y:  (1000,)
number of zeros in Y:  500
number of ones in Y:  500
Y =  [0 0 1 0 1 0 0 1 1 1 1 1 0 1 1 0 0 1 0 1 1 0 1 0 0 1 1 0 1 1 1 0 0 0 0 1 1
 1 0 0 0 0 0 0 1 0 1 0 1 0 0 1 0 0 0 0 1 1 1 0 0 1 1 0 1 0 0 0 0 0 1 1 1 0
 1 0 0 1 0 1 1 1 1 0 1 0 1 1 0 1 0 0 1 0 0 1 0 0 0 0 1 1 0 1 1 0 0 0 1 0 0
 0 0 1 1 0 0 0 1 0 1 1 0 0 1 1 0 0 1 1 0 0 1 0 0 0 1 0 0 1 1 0 0 0 1 0 0 0
 0 1 0 1 1 1 1 0 1 1 1 0 1 1 0 0 1 0 1 0 1 1 0 0 1 0 1 0 0 1 1 1 1 1 0 1 1
 0 0 0 0 0 1 0 1 0 1 1 1 0 0 1 1 0 0 1 1 1 0 0 1 1 0 0 0 0 0 0 0 0 0 1 0 0
 1 1 0 1 0 0 1 1 0 1 0 0 0 0 1 1 1 1 0 0 0 0 0 1 0 0 1 1 1 0 1 1 1 1 1 1 1
 0 0 0 0 0 0 1 1 1 1 0 1 0 0 0 1 1 1 0 0 1 0 1 1 1 0 0 1 0 0 0 1 1 0 1 0 0
 1 1 0 0 0 0 0 0 1 0 0 0 1 0 1 0 0 1 0 1 1 1 0 0 1 0 1 0 1 0 1 0 0 1 0 1 1
 0 0 1 0 1 1 0 0 1 1 1 1 0 0 1 1 0 1 1 0 1 1 0 1 1 0 0 1 0 0 1 0 0 0 1 1 0
 0 1 1 1 1 0 0 0 0 1 0 0 0 0 1 1 0 1 1 1 0 0 1 0 1 0 1 1 0 0 1 0 1 1 1 1 0
 1 0 0 0 0 1 0 0 0 0 0 0 1 1 1 1 0 1 0 1 1 1 1 0 1 1 1 1 0 1 0 0 1 1 1 1 0
 1 1 1 0 1 1 1 1 1 0 1 1 1 0 0 1 0 1 0 0 0 1 1 1 0 0 0 0 1 1 1 0 1 0 1 0 0
 1 0 1 0 1 1 0 1 0 0 0 1 1 1 1 0 1 1 0 0 1 1 0 0 0 1 0 0 0 0 1 0 1 1 0 1 1
 1 0 0 1 0 0 0 0 0 1 1 1 1 0 1 0 1 1 0 0 0 0 0 1 0 0 0 0 1 1 1 1 0 1 1 0 0
 0 1 1 0 1 1 0 1 0 1 1 0 1 0 0 0 0 1 1 0 0 0 1 1 1 0 1 0 1 1 1 1 0 1 1 1 1
 0 1 1 0 0 1 1 1 1 1 0 0 0 0 0 1 1 0 0 1 0 1 1 1 1 0 0 1 0 0 0 0 0 0 0 0 0
 1 0 0 1 0 1 1 0 1 1 0 1 1 0 1 0 1 0 1 0 1 0 1 0 1 1 1 0 0 1 1 1 0 1 1 0 1
 0 1 0 1 1 0 1 1 1 0 1 1 0 1 0 1 1 1 0 1 1 0 1 0 1 0 0 0 0 1 1 1 0 1 1 1 0
 1 1 0 0 1 0 0 1 1 1 0 1 1 0 0 0 1 1 0 1 0 0 1 0 0 1 0 0 1 1 1 1 1 1 1 0 1
 0 0 1 1 1 1 0 0 1 0 0 0 0 1 1 1 0 1 0 0 0 0 1 0 1 1 1 1 1 1 0 0 1 1 0 1 0
 1 0 1 1 0 0 1 1 1 0 1 1 0 1 0 0 1 0 1 1 1 1 1 1 0 0 1 0 0 0 0 0 1 0 1 0 0
 0 0 0 1 1 1 1 1 0 0 0 1 1 0 1 0 1 0 1 0 1 0 0 0 1 0 1 0 1 0 0 0 1 1 0 0 1
 0 1 1 1 0 0 1 0 1 0 0 0 1 0 1 1 0 0 0 1 0 0 1 0 1 1 0 0 0 0 1 0 1 1 1 0 1
 0 1 0 0 0 1 1 0 1 0 0 1 1 0 1 0 0 1 0 1 1 0 0 0 0 0 0 1 1 0 0 0 0 1 0 0 1
 1 1 0 1 0 0 0 0 1 1 1 0 1 0 0 1 0 0 0 0 0 1 1 0 0 1 1 1 0 1 0 0 0 1 1 1 0
 1 1 1 0 1 0 1 1 1 0 0 0 1 1 1 1 0 1 0 1 0 1 1 1 1 1 0 1 1 0 1 1 0 1 1 1 1
 1]
# the function making up the graph of a dataset
def make_plot(X, y, plot_name, file_name=None, XX=None, YY=None, preds=None, dark=False):
    if (dark):
        plt.style.use('dark_background')
    else:
        sns.set_style("whitegrid")
    plt.figure(figsize=(16,12))
    axes = plt.gca()
    axes.set(xlabel="$X_1$", ylabel="$X_2$")
    plt.title(plot_name, fontsize=30)
    plt.subplots_adjust(left=0.20)
    plt.subplots_adjust(right=0.80)
    if(XX is not None and YY is not None and preds is not None):
        plt.contourf(XX, YY, preds.reshape(XX.shape), 25, alpha = 1, cmap=cm.Spectral)
        plt.contour(XX, YY, preds.reshape(XX.shape), levels=[.5], cmap="Greys", vmin=0, vmax=.6)
    plt.scatter(X[:, 0], X[:, 1], c=y.ravel(), s=40, cmap=plt.cm.Spectral, edgecolors='black')
    if(file_name):
        plt.savefig(file_name)
        plt.close()
make_plot(X, y, "Dataset")

image.png


请同学们依次尝试将 sklearn.make_moons() 函数中 noise 取 0, 0.2, 0.4, 0.8, 1.0 等不同值,查看并截图保存对应的数据集图像和训练结果,分析其趋势,及出现这样趋势的原因。


然后,请同学们将 sklearn.make_moons() 中 noise 固定为 0.8,将sklearn.model_selection.train_test_split() 函数中的 TEST_SIZE 改为 0.98,观察和原先 noise 为 0.2、TEST_SIZE 为 0.1 的情况相比,训练集上的准确度和测试集上的准确度都分别发生了什么变化?出现这样变化的原因是什么?


调整 make_moons() 函数中的 noise 和 train_test_split() 函数中的 TEST_SIZE,改变数据集中点的分布。 将 sklearn.make_moons() 中 noise 固定为 0.8,将sklearn.model_selection.train_test_split() 函数中的 TEST_SIZE 改为 0.98,观察和原先 noise 为 0.2、TEST_SIZE 为 0.1 的情况相比,训练过程折线图表示如图所示,测试结果如图所示。

image.png


结论:当测试集占总数据量10%,noise等于0.2时,训练速度较慢,但是Acc和Cost曲线比较平滑,且测试效果和模型最终在训练集上得到的效果相差不大,即可以学习出一种较好的模型。而当测试集占总数据量98%,noise等于0.8时,训练速度较快,但是Acc曲线出现毛刺,且测试效果和模型最终在训练集上得到的效果很大,即出现过拟合现象。分析可知,当训练集占比过小,模型只能学习到很少的知识,无法得到好的模型。


12. 模型训练及测试

下面我们来调用 train 函数对模型进行训练。


# let's train the neural network
params_values = train(X=np.transpose(X_train), Y=np.transpose(y_train.reshape((y_train.shape[0], 1))), 
                      nn_architecture=NN_ARCHITECTURE, epochs=10000, learning_rate=0.01)

Iteration: 00000 - cost: 0.69365 - accuracy: 0.50444
X.shape:  (2, 900)
Y_hat.shape:  (1, 900)
Y.shape:  (1, 900)
Iteration: 00050 - cost: 0.69349 - accuracy: 0.50444
Iteration: 00100 - cost: 0.69334 - accuracy: 0.50444
Iteration: 00150 - cost: 0.69319 - accuracy: 0.50444
Iteration: 00200 - cost: 0.69307 - accuracy: 0.50444
Iteration: 00250 - cost: 0.69295 - accuracy: 0.50444
Iteration: 00300 - cost: 0.69284 - accuracy: 0.50444
Iteration: 00350 - cost: 0.69272 - accuracy: 0.50444
Iteration: 00400 - cost: 0.69260 - accuracy: 0.50444
Iteration: 00450 - cost: 0.69249 - accuracy: 0.50444
Iteration: 00500 - cost: 0.69238 - accuracy: 0.50444
Iteration: 00550 - cost: 0.69228 - accuracy: 0.50444
Iteration: 00600 - cost: 0.69217 - accuracy: 0.50444
Iteration: 00650 - cost: 0.69206 - accuracy: 0.50444
Iteration: 00700 - cost: 0.69194 - accuracy: 0.50444
Iteration: 00750 - cost: 0.69182 - accuracy: 0.50444
Iteration: 00800 - cost: 0.69170 - accuracy: 0.50444
Iteration: 00850 - cost: 0.69156 - accuracy: 0.50444
Iteration: 00900 - cost: 0.69142 - accuracy: 0.50444
Iteration: 00950 - cost: 0.69126 - accuracy: 0.50444
Iteration: 01000 - cost: 0.69109 - accuracy: 0.50444
Iteration: 01050 - cost: 0.69090 - accuracy: 0.50444
Iteration: 01100 - cost: 0.69070 - accuracy: 0.50444
Iteration: 01150 - cost: 0.69049 - accuracy: 0.50444
Iteration: 01200 - cost: 0.69025 - accuracy: 0.50444
Iteration: 01250 - cost: 0.69000 - accuracy: 0.50889
Iteration: 01300 - cost: 0.68972 - accuracy: 0.52889
Iteration: 01350 - cost: 0.68941 - accuracy: 0.59778
Iteration: 01400 - cost: 0.68907 - accuracy: 0.66667
Iteration: 01450 - cost: 0.68869 - accuracy: 0.72222
Iteration: 01500 - cost: 0.68827 - accuracy: 0.76111
Iteration: 01550 - cost: 0.68780 - accuracy: 0.79111
Iteration: 01600 - cost: 0.68726 - accuracy: 0.81333
Iteration: 01650 - cost: 0.68666 - accuracy: 0.82778
Iteration: 01700 - cost: 0.68596 - accuracy: 0.83778
Iteration: 01750 - cost: 0.68512 - accuracy: 0.84000
Iteration: 01800 - cost: 0.68416 - accuracy: 0.84222
Iteration: 01850 - cost: 0.68308 - accuracy: 0.84444
Iteration: 01900 - cost: 0.68185 - accuracy: 0.84333
Iteration: 01950 - cost: 0.68042 - accuracy: 0.84222
Iteration: 02000 - cost: 0.67875 - accuracy: 0.84222
Iteration: 02050 - cost: 0.67680 - accuracy: 0.84444
Iteration: 02100 - cost: 0.67453 - accuracy: 0.84444
Iteration: 02150 - cost: 0.67183 - accuracy: 0.84667
Iteration: 02200 - cost: 0.66859 - accuracy: 0.84556
Iteration: 02250 - cost: 0.66471 - accuracy: 0.84222
Iteration: 02300 - cost: 0.66004 - accuracy: 0.84333
Iteration: 02350 - cost: 0.65437 - accuracy: 0.84222
Iteration: 02400 - cost: 0.64757 - accuracy: 0.84444
Iteration: 02450 - cost: 0.63942 - accuracy: 0.84778
Iteration: 02500 - cost: 0.62966 - accuracy: 0.84444
Iteration: 02550 - cost: 0.61796 - accuracy: 0.84111
Iteration: 02600 - cost: 0.60398 - accuracy: 0.84111
Iteration: 02650 - cost: 0.58764 - accuracy: 0.84111
Iteration: 02700 - cost: 0.56876 - accuracy: 0.84222
Iteration: 02750 - cost: 0.54730 - accuracy: 0.84222
Iteration: 02800 - cost: 0.52368 - accuracy: 0.85000
Iteration: 02850 - cost: 0.49867 - accuracy: 0.85333
Iteration: 02900 - cost: 0.47325 - accuracy: 0.85556
Iteration: 02950 - cost: 0.44840 - accuracy: 0.85556
Iteration: 03000 - cost: 0.42476 - accuracy: 0.85889
Iteration: 03050 - cost: 0.40263 - accuracy: 0.86333
Iteration: 03100 - cost: 0.38221 - accuracy: 0.86222
Iteration: 03150 - cost: 0.36367 - accuracy: 0.86778
Iteration: 03200 - cost: 0.34730 - accuracy: 0.87111
Iteration: 03250 - cost: 0.33327 - accuracy: 0.87444
Iteration: 03300 - cost: 0.32149 - accuracy: 0.87778
Iteration: 03350 - cost: 0.31175 - accuracy: 0.87889
Iteration: 03400 - cost: 0.30379 - accuracy: 0.88000
Iteration: 03450 - cost: 0.29733 - accuracy: 0.88111
Iteration: 03500 - cost: 0.29209 - accuracy: 0.88111
Iteration: 03550 - cost: 0.28783 - accuracy: 0.88111
Iteration: 03600 - cost: 0.28431 - accuracy: 0.88222
Iteration: 03650 - cost: 0.28133 - accuracy: 0.88333
Iteration: 03700 - cost: 0.27875 - accuracy: 0.88333
Iteration: 03750 - cost: 0.27648 - accuracy: 0.88333
Iteration: 03800 - cost: 0.27445 - accuracy: 0.88333
Iteration: 03850 - cost: 0.27262 - accuracy: 0.88222
Iteration: 03900 - cost: 0.27090 - accuracy: 0.88111
Iteration: 03950 - cost: 0.26930 - accuracy: 0.88000
Iteration: 04000 - cost: 0.26780 - accuracy: 0.88000
Iteration: 04050 - cost: 0.26634 - accuracy: 0.88000
Iteration: 04100 - cost: 0.26495 - accuracy: 0.88000
Iteration: 04150 - cost: 0.26356 - accuracy: 0.88000
Iteration: 04200 - cost: 0.26215 - accuracy: 0.87889
Iteration: 04250 - cost: 0.26074 - accuracy: 0.88000
Iteration: 04300 - cost: 0.25933 - accuracy: 0.88222
Iteration: 04350 - cost: 0.25793 - accuracy: 0.88333
Iteration: 04400 - cost: 0.25652 - accuracy: 0.88444
Iteration: 04450 - cost: 0.25510 - accuracy: 0.88444
Iteration: 04500 - cost: 0.25369 - accuracy: 0.88444
Iteration: 04550 - cost: 0.25227 - accuracy: 0.88333
Iteration: 04600 - cost: 0.25087 - accuracy: 0.88444
Iteration: 04650 - cost: 0.24944 - accuracy: 0.88556
Iteration: 04700 - cost: 0.24798 - accuracy: 0.88556
Iteration: 04750 - cost: 0.24650 - accuracy: 0.88667
Iteration: 04800 - cost: 0.24497 - accuracy: 0.88778
Iteration: 04850 - cost: 0.24336 - accuracy: 0.88778
Iteration: 04900 - cost: 0.24171 - accuracy: 0.88889
Iteration: 04950 - cost: 0.23999 - accuracy: 0.89000
Iteration: 05000 - cost: 0.23821 - accuracy: 0.89000
Iteration: 05050 - cost: 0.23635 - accuracy: 0.89222
Iteration: 05100 - cost: 0.23441 - accuracy: 0.89333
Iteration: 05150 - cost: 0.23237 - accuracy: 0.89333
Iteration: 05200 - cost: 0.23021 - accuracy: 0.89444
Iteration: 05250 - cost: 0.22792 - accuracy: 0.89556
Iteration: 05300 - cost: 0.22550 - accuracy: 0.89667
Iteration: 05350 - cost: 0.22292 - accuracy: 0.89667
Iteration: 05400 - cost: 0.22018 - accuracy: 0.89778
Iteration: 05450 - cost: 0.21728 - accuracy: 0.90000
Iteration: 05500 - cost: 0.21418 - accuracy: 0.90222
Iteration: 05550 - cost: 0.21087 - accuracy: 0.90444
Iteration: 05600 - cost: 0.20736 - accuracy: 0.90556
Iteration: 05650 - cost: 0.20364 - accuracy: 0.91111
Iteration: 05700 - cost: 0.19973 - accuracy: 0.91333
Iteration: 05750 - cost: 0.19562 - accuracy: 0.91444
Iteration: 05800 - cost: 0.19133 - accuracy: 0.91889
Iteration: 05850 - cost: 0.18686 - accuracy: 0.92222
Iteration: 05900 - cost: 0.18224 - accuracy: 0.92556
Iteration: 05950 - cost: 0.17747 - accuracy: 0.92778
Iteration: 06000 - cost: 0.17260 - accuracy: 0.93000
Iteration: 06050 - cost: 0.16767 - accuracy: 0.93333
Iteration: 06100 - cost: 0.16269 - accuracy: 0.93444
Iteration: 06150 - cost: 0.15775 - accuracy: 0.93778
Iteration: 06200 - cost: 0.15289 - accuracy: 0.93778
Iteration: 06250 - cost: 0.14812 - accuracy: 0.93889
Iteration: 06300 - cost: 0.14350 - accuracy: 0.94333
Iteration: 06350 - cost: 0.13907 - accuracy: 0.94444
Iteration: 06400 - cost: 0.13485 - accuracy: 0.94444
Iteration: 06450 - cost: 0.13086 - accuracy: 0.94556
Iteration: 06500 - cost: 0.12711 - accuracy: 0.94667
Iteration: 06550 - cost: 0.12361 - accuracy: 0.95000
Iteration: 06600 - cost: 0.12035 - accuracy: 0.95444
Iteration: 06650 - cost: 0.11733 - accuracy: 0.95778
Iteration: 06700 - cost: 0.11456 - accuracy: 0.95778
Iteration: 06750 - cost: 0.11200 - accuracy: 0.95889
Iteration: 06800 - cost: 0.10963 - accuracy: 0.96000
Iteration: 06850 - cost: 0.10745 - accuracy: 0.96000
Iteration: 06900 - cost: 0.10544 - accuracy: 0.96222
Iteration: 06950 - cost: 0.10359 - accuracy: 0.96111
Iteration: 07000 - cost: 0.10188 - accuracy: 0.96111
Iteration: 07050 - cost: 0.10031 - accuracy: 0.96222
Iteration: 07100 - cost: 0.09885 - accuracy: 0.96222
Iteration: 07150 - cost: 0.09750 - accuracy: 0.96222
Iteration: 07200 - cost: 0.09623 - accuracy: 0.96222
Iteration: 07250 - cost: 0.09506 - accuracy: 0.96444
Iteration: 07300 - cost: 0.09399 - accuracy: 0.96556
Iteration: 07350 - cost: 0.09298 - accuracy: 0.96556
Iteration: 07400 - cost: 0.09203 - accuracy: 0.96667
Iteration: 07450 - cost: 0.09118 - accuracy: 0.96667
Iteration: 07500 - cost: 0.09041 - accuracy: 0.96667
Iteration: 07550 - cost: 0.08969 - accuracy: 0.96667
Iteration: 07600 - cost: 0.08898 - accuracy: 0.96667
Iteration: 07650 - cost: 0.08831 - accuracy: 0.96667
Iteration: 07700 - cost: 0.08767 - accuracy: 0.96667
Iteration: 07750 - cost: 0.08707 - accuracy: 0.96667
Iteration: 07800 - cost: 0.08647 - accuracy: 0.96778
Iteration: 07850 - cost: 0.08594 - accuracy: 0.96667
Iteration: 07900 - cost: 0.08544 - accuracy: 0.96667
Iteration: 07950 - cost: 0.08497 - accuracy: 0.96667
Iteration: 08000 - cost: 0.08453 - accuracy: 0.96556
Iteration: 08050 - cost: 0.08412 - accuracy: 0.96667
Iteration: 08100 - cost: 0.08371 - accuracy: 0.96667
Iteration: 08150 - cost: 0.08332 - accuracy: 0.96889
Iteration: 08200 - cost: 0.08295 - accuracy: 0.96889
Iteration: 08250 - cost: 0.08259 - accuracy: 0.96889
Iteration: 08300 - cost: 0.08219 - accuracy: 0.96889
Iteration: 08350 - cost: 0.08180 - accuracy: 0.96778
Iteration: 08400 - cost: 0.08145 - accuracy: 0.96778
Iteration: 08450 - cost: 0.08114 - accuracy: 0.96778
Iteration: 08500 - cost: 0.08084 - accuracy: 0.96889
Iteration: 08550 - cost: 0.08055 - accuracy: 0.96889
Iteration: 08600 - cost: 0.08025 - accuracy: 0.97000
Iteration: 08650 - cost: 0.07996 - accuracy: 0.97000
Iteration: 08700 - cost: 0.07968 - accuracy: 0.97000
Iteration: 08750 - cost: 0.07939 - accuracy: 0.97000
Iteration: 08800 - cost: 0.07912 - accuracy: 0.96889
Iteration: 08850 - cost: 0.07885 - accuracy: 0.96889
Iteration: 08900 - cost: 0.07860 - accuracy: 0.96889
Iteration: 08950 - cost: 0.07836 - accuracy: 0.96889
Iteration: 09000 - cost: 0.07812 - accuracy: 0.96889
Iteration: 09050 - cost: 0.07788 - accuracy: 0.96889
Iteration: 09100 - cost: 0.07765 - accuracy: 0.96889
Iteration: 09150 - cost: 0.07743 - accuracy: 0.96889
Iteration: 09200 - cost: 0.07721 - accuracy: 0.96889
Iteration: 09250 - cost: 0.07698 - accuracy: 0.96889
Iteration: 09300 - cost: 0.07676 - accuracy: 0.96889
Iteration: 09350 - cost: 0.07653 - accuracy: 0.96889
Iteration: 09400 - cost: 0.07631 - accuracy: 0.96889
Iteration: 09450 - cost: 0.07610 - accuracy: 0.96889
Iteration: 09500 - cost: 0.07588 - accuracy: 0.96889
Iteration: 09550 - cost: 0.07568 - accuracy: 0.96889
Iteration: 09600 - cost: 0.07550 - accuracy: 0.96889
Iteration: 09650 - cost: 0.07532 - accuracy: 0.96889
Iteration: 09700 - cost: 0.07516 - accuracy: 0.96889
Iteration: 09750 - cost: 0.07500 - accuracy: 0.96889
Iteration: 09800 - cost: 0.07485 - accuracy: 0.96889
Iteration: 09850 - cost: 0.07470 - accuracy: 0.96889
Iteration: 09900 - cost: 0.07456 - accuracy: 0.96889
Iteration: 09950 - cost: 0.07442 - accuracy: 0.96889

调用一次 full_forward_propagation(),在测试集上评估训练好的模型。

# prediction
Y_test_hat, _ = full_forward_propagation(np.transpose(X_test), params_values, NN_ARCHITECTURE)
# accuracy achieved on the test set
acc_test = get_accuracy_value(Y_test_hat, np.transpose(y_test.reshape((y_test.shape[0], 1))))
print("Test set accuracy: {:.2f}".format(acc_test))


Test set accuracy: 0.98

相关文章
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】32. 卷积神经网络之稠密连接网络(DenseNet)介绍及其Pytorch实现
【从零开始学习深度学习】32. 卷积神经网络之稠密连接网络(DenseNet)介绍及其Pytorch实现
|
1月前
|
机器学习/深度学习 自然语言处理 算法
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】31. 卷积神经网络之残差网络(ResNet)介绍及其Pytorch实现
【从零开始学习深度学习】31. 卷积神经网络之残差网络(ResNet)介绍及其Pytorch实现
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】36. 门控循环神经网络之长短期记忆网络(LSTM)介绍、Pytorch实现LSTM并进行训练预测
【从零开始学习深度学习】36. 门控循环神经网络之长短期记忆网络(LSTM)介绍、Pytorch实现LSTM并进行训练预测
|
18天前
|
并行计算 PyTorch 程序员
老程序员分享:Pytorch入门之Siamese网络
老程序员分享:Pytorch入门之Siamese网络
16 0
|
1月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】50.Pytorch_NLP项目实战:卷积神经网络textCNN在文本情感分类的运用
【从零开始学习深度学习】50.Pytorch_NLP项目实战:卷积神经网络textCNN在文本情感分类的运用
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】35. 门控循环神经网络之门控循环单元(gated recurrent unit,GRU)介绍、Pytorch实现GRU并进行训练预测
【从零开始学习深度学习】35. 门控循环神经网络之门控循环单元(gated recurrent unit,GRU)介绍、Pytorch实现GRU并进行训练预测
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】26.卷积神经网络之AlexNet模型介绍及其Pytorch实现【含完整代码】
【从零开始学习深度学习】26.卷积神经网络之AlexNet模型介绍及其Pytorch实现【含完整代码】
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】28.卷积神经网络之NiN模型介绍及其Pytorch实现【含完整代码】
【从零开始学习深度学习】28.卷积神经网络之NiN模型介绍及其Pytorch实现【含完整代码】
|
13天前
|
机器学习/深度学习 编解码 数据可视化
图神经网络版本的Kolmogorov Arnold(KAN)代码实现和效果对比
目前我们看到有很多使用KAN替代MLP的实验,但是目前来说对于图神经网络来说还没有类似的实验,今天我们就来使用KAN创建一个图神经网络Graph Kolmogorov Arnold(GKAN),来测试下KAN是否可以在图神经网络方面有所作为。
35 0