Run it directly
Open the notebook 使用Numpy实现卷积神经网络 and click "Open in DSW" in the upper-right corner.
Implementing the forward pass and backpropagation of a CNN with NumPy
The dataset used in this article comes from Kaggle and can be downloaded here.
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
data = pd.read_csv('train_numpy.csv')
data.shape
(2290, 785)
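The dataset contains 2290 examples. In each row, the first column is the digit label and the remaining 784 columns are the pixel values of a 28 × 28 grayscale image.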
data = np.array(data)
m, n = data.shape
np.random.shuffle(data)  # shuffle before splitting into dev and training sets

# The first 1000 rows form the dev set; transpose so each column is one example
data_dev = data[0:1000].T
Y_dev = data_dev[0]
X_dev = data_dev[1:n]
X_dev = X_dev / 255.  # scale pixel values to [0, 1]

# The remaining rows form the training set
data_train = data[1000:m].T
Y_train = data_train[0]
X_train = data_train[1:n]
X_train = X_train / 255.
_, m_train = X_train.shape  # number of training examples
In this example, we define a simple neural network with only two fully connected layers:
- The activation function is ReLU, which effectively mitigates the vanishing-gradient problem
- Finally, Softmax converts each output neuron's value into a number in [0, 1], so the outputs can be read as class probabilities (both functions are written out below)
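Written out, the two functions are:

$$\mathrm{ReLU}(z) = \max(0, z), \qquad \mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_j e^{z_j}}$$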
Forward pass
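With $X$ holding one example per column, the forward pass computes (matching the code below):

$$Z^{[1]} = W^{[1]} X + b^{[1]}, \qquad A^{[1]} = \mathrm{ReLU}(Z^{[1]})$$
$$Z^{[2]} = W^{[2]} A^{[1]} + b^{[2]}, \qquad A^{[2]} = \mathrm{softmax}(Z^{[2]})$$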
Backward pass
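With $Y$ the one-hot label matrix and $m$ the number of examples, backpropagation of the cross-entropy loss gives:

$$dZ^{[2]} = A^{[2]} - Y$$
$$dW^{[2]} = \frac{1}{m}\, dZ^{[2]} A^{[1]\top}, \qquad db^{[2]} = \frac{1}{m} \sum dZ^{[2]}$$
$$dZ^{[1]} = W^{[2]\top} dZ^{[2]} \odot \mathrm{ReLU}'(Z^{[1]})$$
$$dW^{[1]} = \frac{1}{m}\, dZ^{[1]} X^{\top}, \qquad db^{[1]} = \frac{1}{m} \sum dZ^{[1]}$$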
Parameter update formulas
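With learning rate $\alpha$, each layer $l$ is updated by plain gradient descent:

$$W^{[l]} \leftarrow W^{[l]} - \alpha\, dW^{[l]}, \qquad b^{[l]} \leftarrow b^{[l]} - \alpha\, db^{[l]}$$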
Following the formulas above, define the parameter initialization and the activation functions:
def init_params():
    # Uniform initialization in [-0.5, 0.5) centers the weights around zero
    W1 = np.random.rand(10, 784) - 0.5
    b1 = np.random.rand(10, 1) - 0.5
    W2 = np.random.rand(10, 10) - 0.5
    b2 = np.random.rand(10, 1) - 0.5
    return W1, b1, W2, b2

def ReLU(Z):
    return np.maximum(Z, 0)

def ReLU_deriv(Z):
    # ReLU's derivative is 1 where Z > 0, else 0
    return Z > 0

def softmax(Z):
    # Column-wise softmax: each column becomes a probability distribution
    A = np.exp(Z) / np.sum(np.exp(Z), axis=0)
    return A
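As a quick, illustrative check (not part of the original notebook), every column of the softmax output should sum to 1:

Z = np.random.randn(10, 3)  # 10 classes, 3 examples
A = softmax(Z)
print(A.sum(axis=0))        # expect [1. 1. 1.]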
Define the forward-pass and backward-pass functions:
def forward_prop(W1, b1, W2, b2, X):
    Z1 = W1.dot(X) + b1
    A1 = ReLU(Z1)
    Z2 = W2.dot(A1) + b2
    A2 = softmax(Z2)
    return Z1, A1, Z2, A2

def one_hot(Y):
    # Convert integer labels to one-hot columns
    one_hot_Y = np.zeros((Y.size, int(Y.max()) + 1)).astype(int)
    one_hot_Y[np.arange(Y.size), Y.astype(int)] = 1
    return one_hot_Y.T

def backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y):
    m = X.shape[1]  # number of examples in this batch, not the global dataset size
    one_hot_Y = one_hot(Y)
    dZ2 = A2 - one_hot_Y
    dW2 = 1 / m * dZ2.dot(A1.T)
    db2 = 1 / m * np.sum(dZ2, axis=1, keepdims=True)  # one bias gradient per neuron
    dZ1 = W2.T.dot(dZ2) * ReLU_deriv(Z1)
    dW1 = 1 / m * dZ1.dot(X.T)
    db1 = 1 / m * np.sum(dZ1, axis=1, keepdims=True)
    return dW1, db1, dW2, db2

def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):
    # Vanilla gradient descent step
    W1 = W1 - alpha * dW1
    b1 = b1 - alpha * db1
    W2 = W2 - alpha * dW2
    b2 = b2 - alpha * db2
    return W1, b1, W2, b2
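To make sure the backward pass is correct, we can compare an analytic gradient against a numerical one. The sketch below is illustrative and not part of the original notebook: cross_entropy and the probed index (3, 4) are names introduced here, and it reuses the functions defined above.

# Finite-difference gradient check (illustrative sketch)
def cross_entropy(W1, b1, W2, b2, X, Y):
    _, _, _, A2 = forward_prop(W1, b1, W2, b2, X)
    m = X.shape[1]
    p = A2[Y.astype(int), np.arange(m)]  # probability assigned to the true class
    return -np.mean(np.log(p + 1e-12))   # small constant guards against log(0)

W1, b1, W2, b2 = init_params()
Z1, A1, Z2, A2 = forward_prop(W1, b1, W2, b2, X_train)
dW1, db1, dW2, db2 = backward_prop(Z1, A1, Z2, A2, W1, W2, X_train, Y_train)

eps = 1e-5
i, j = 3, 4  # an arbitrary entry of W2
W2_plus, W2_minus = W2.copy(), W2.copy()
W2_plus[i, j] += eps
W2_minus[i, j] -= eps
numeric = (cross_entropy(W1, b1, W2_plus, b2, X_train, Y_train)
           - cross_entropy(W1, b1, W2_minus, b2, X_train, Y_train)) / (2 * eps)
print(numeric, dW2[i, j])  # the two numbers should agree closely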
Define the prediction, gradient-descent, and accuracy functions:
def get_predictions(A2):
    # Predicted class = index of the largest softmax output in each column
    return np.argmax(A2, 0)

def get_accuracy(predictions, Y):
    print(predictions, Y)
    return np.sum(predictions == Y) / Y.size

def gradient_descent(X, Y, alpha, iterations):
    W1, b1, W2, b2 = init_params()
    for i in range(iterations):
        Z1, A1, Z2, A2 = forward_prop(W1, b1, W2, b2, X)
        dW1, db1, dW2, db2 = backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y)
        W1, b1, W2, b2 = update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha)
        if i % 10 == 0:
            print("Iteration: ", i)
            predictions = get_predictions(A2)
            print(get_accuracy(predictions, Y))
    return W1, b1, W2, b2
Start training:
W1, b1, W2, b2 = gradient_descent(X_train, Y_train, 0.10, 500)
Iteration: 0
[2 2 2 ... 1 3 2] [9. 7. 2. ... 5. 6. 0.]
0.1263565891472868
Iteration: 10
[1 1 0 ... 1 3 6] [9. 7. 2. ... 5. 6. 0.]
0.1821705426356589
Iteration: 20
[1 1 2 ... 4 0 6] [9. 7. 2. ... 5. 6. 0.]
0.2372093023255814
Iteration: 30
[1 1 6 ... 0 0 6] [9. 7. 2. ... 5. 6. 0.]
0.29147286821705426
Iteration: 40
[1 1 6 ... 0 0 6] [9. 7. 2. ... 5. 6. 0.]
0.3333333333333333
Iteration: 50
[1 1 6 ... 0 0 0] [9. 7. 2. ... 5. 6. 0.]
0.38527131782945734
Iteration: 60
[1 1 6 ... 0 0 0] [9. 7. 2. ... 5. 6. 0.]
0.4325581395348837
Iteration: 70
[1 1 2 ... 0 0 0] [9. 7. 2. ... 5. 6. 0.]
0.4906976744186046
Iteration: 80
[5 7 2 ... 0 0 0] [9. 7. 2. ... 5. 6. 0.]
0.537984496124031
Iteration: 90
[5 7 2 ... 0 0 0] [9. 7. 2. ... 5. 6. 0.]
0.5612403100775194
Iteration: 100
[5 7 2 ... 0 0 0] [9. 7. 2. ... 5. 6. 0.]
0.5844961240310077
Iteration: 110
[5 7 2 ... 0 6 0] [9. 7. 2. ... 5. 6. 0.]
0.6069767441860465
Iteration: 120
[5 7 2 ... 0 6 0] [9. 7. 2. ... 5. 6. 0.]
0.624031007751938
Iteration: 130
[5 7 2 ... 0 6 0] [9. 7. 2. ... 5. 6. 0.]
0.6403100775193798
Iteration: 140
[9 7 2 ... 0 6 0] [9. 7. 2. ... 5. 6. 0.]
0.662015503875969
Iteration: 150
[9 7 2 ... 0 6 0] [9. 7. 2. ... 5. 6. 0.]
0.6713178294573643
Iteration: 160
[9 7 2 ... 0 6 0] [9. 7. 2. ... 5. 6. 0.]
0.6922480620155039
Iteration: 170
[9 7 2 ... 0 6 0] [9. 7. 2. ... 5. 6. 0.]
0.6992248062015504
Iteration: 180
[9 7 2 ... 0 6 0] [9. 7. 2. ... 5. 6. 0.]
0.7116279069767442
Iteration: 190
[9 7 2 ... 0 6 0] [9. 7. 2. ... 5. 6. 0.]
0.7162790697674418
Iteration: 200
[9 7 2 ... 0 6 0] [9. 7. 2. ... 5. 6. 0.]
0.724031007751938
Iteration: 210
[9 7 2 ... 0 6 0] [9. 7. 2. ... 5. 6. 0.]
0.727906976744186
Iteration: 220
[9 7 2 ... 0 6 0] [9. 7. 2. ... 5. 6. 0.]
0.7333333333333333
Iteration: 230
[9 7 2 ... 0 6 0] [9. 7. 2. ... 5. 6. 0.]
0.7395348837209302
Iteration: 240
[9 7 2 ... 0 6 0] [9. 7. 2. ... 5. 6. 0.]
0.7434108527131783
Iteration: 250
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.7550387596899225
Iteration: 260
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.7581395348837209
Iteration: 270
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.7697674418604651
Iteration: 280
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.7751937984496124
Iteration: 290
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.7798449612403101
Iteration: 300
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.7821705426356589
Iteration: 310
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.7883720930232558
Iteration: 320
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.7937984496124031
Iteration: 330
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.7968992248062016
Iteration: 340
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8007751937984496
Iteration: 350
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8062015503875969
Iteration: 360
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8093023255813954
Iteration: 370
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8147286821705426
Iteration: 380
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8178294573643411
Iteration: 390
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8217054263565892
Iteration: 400
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8224806201550388
Iteration: 410
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8255813953488372
Iteration: 420
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8255813953488372
Iteration: 430
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8263565891472868
Iteration: 440
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8310077519379845
Iteration: 450
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8333333333333334
Iteration: 460
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8348837209302326
Iteration: 470
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8395348837209302
Iteration: 480
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8395348837209302
Iteration: 490
[9 7 2 ... 5 6 0] [9. 7. 2. ... 5. 6. 0.]
0.8418604651162791
def make_predictions(X, W1, b1, W2, b2):
    _, _, _, A2 = forward_prop(W1, b1, W2, b2, X)
    predictions = get_predictions(A2)
    return predictions

def test_prediction(index, W1, b1, W2, b2):
    current_image = X_train[:, index, None]
    prediction = make_predictions(current_image, W1, b1, W2, b2)
    label = Y_train[index]
    print("Prediction: ", prediction)
    print("Label: ", label)
    # Undo the earlier /255 scaling and reshape to 28x28 for display
    current_image = current_image.reshape((28, 28)) * 255
    plt.gray()
    plt.imshow(current_image, interpolation='nearest')
    plt.show()
Below we pick a few inputs and check whether our model correctly recognizes the digit in each image:
test_prediction(0, W1, b1, W2, b2)
test_prediction(1, W1, b1, W2, b2)
test_prediction(2, W1, b1, W2, b2)
test_prediction(3, W1, b1, W2, b2)
Prediction: [9]
Label: 9.0

Prediction: [7]
Label: 7.0

Prediction: [2]
Label: 2.0

Prediction: [7]
Label: 7.0
Finally, compute the model's accuracy on the held-out dev set:
dev_predictions = make_predictions(X_dev, W1, b1, W2, b2)
print("Accuracy: ", get_accuracy(dev_predictions, Y_dev))
[6 3 0 4 8 6 8 0 8 4 9 7 5 9 9 8 1 8 8 8 9 9 8 4 6 0 9 4 1 7 5 6 4 0 4 7 0 8 9 1 9 4 8 3 1 1 8 7 9 1 9 7 0 4 9 2 8 5 2 9 9 7 9 5 4 4 2 1 6 9 7 6 2 4 0 9 5 5 1 8 7 6 3 3 8 4 8 4 9 2 2 3 0 9 7 7 6 9 9 9 9 3 4 7 8 9 7 1 1 3 0 4 7 1 4 1 7 8 1 4 9 2 7 6 8 9 3 5 8 9 2 3 0 7 1 9 4 9 0 9 8 2 6 9 6 0 1 3 4 5 4 9 4 0 0 4 8 1 9 0 9 7 8 8 1 5 5 7 8 0 2 5 5 3 6 3 9 3 4 1 0 1 1 9 8 0 9 8 7 0 1 1 4 2 2 6 7 8 2 7 9 1 2 3 1 1 5 2 5 6 8 4 0 2 4 6 6 6 1 8 8 7 7 2 2 3 4 0 6 9 8 0 5 9 4 1 2 7 8 3 7 9 9 7 7 9 1 3 2 0 5 9 3 6 8 8 6 3 1 1 7 3 0 6 6 2 8 1 3 2 8 8 4 8 1 8 1 8 9 7 6 8 7 6 4 4 4 2 4 7 2 9 6 9 1 1 9 8 0 1 0 8 8 9 3 2 1 7 1 8 4 0 9 9 9 2 6 0 3 6 0 4 0 1 7 0 4 2 5 2 2 9 2 7 6 0 9 6 0 8 7 1 7 0 9 8 0 4 5 2 6 6 6 0 6 9 7 5 8 7 3 6 5 8 3 4 7 6 1 2 2 9 9 9 6 3 0 2 9 1 6 7 9 5 1 9 4 9 4 9 2 1 7 6 8 6 3 7 5 0 8 4 3 1 1 0 6 2 4 9 7 1 6 2 8 1 2 0 0 9 8 9 9 8 4 1 0 7 6 0 0 7 1 0 8 9 4 1 7 4 0 6 7 0 7 5 4 1 9 9 6 7 4 1 0 1 6 9 1 8 3 4 1 7 7 1 8 0 2 7 2 0 0 2 7 7 2 3 3 5 3 4 7 3 1 1 2 0 6 1 8 6 0 4 9 1 6 0 2 6 6 6 7 8 5 3 6 4 6 1 3 4 8 6 4 6 6 0 5 5 1 6 4 1 0 8 2 6 0 1 1 3 7 9 7 4 4 7 1 9 3 9 8 9 2 9 0 2 7 1 1 7 1 2 6 4 9 7 3 3 4 9 3 2 2 2 0 0 6 2 1 3 7 4 4 8 1 2 3 7 5 8 8 5 8 2 5 2 3 8 1 0 2 5 8 6 1 0 7 0 3 7 3 6 4 0 5 9 1 2 5 9 8 7 2 2 8 8 4 2 4 5 5 7 7 0 8 2 1 0 3 0 4 5 0 7 9 9 2 4 0 4 6 2 9 2 2 6 9 0 4 0 8 8 2 4 6 1 9 7 5 4 4 6 1 7 1 1 0 0 1 2 3 1 4 6 0 3 6 7 0 6 9 6 3 0 0 1 5 5 3 8 0 1 1 5 8 9 1 4 4 5 6 7 9 6 4 4 8 1 4 0 2 6 2 0 7 4 9 4 8 9 1 9 1 4 0 6 3 7 5 1 0 0 8 0 0 6 2 7 1 5 0 9 4 8 1 7 1 5 1 0 9 5 6 1 2 9 6 5 4 8 9 5 2 8 7 5 2 7 4 3 0 5 5 3 8 5 9 8 1 8 8 5 4 8 0 2 6 1 0 7 4 0 3 7 4 0 0 4 7 6 6 2 5 1 4 7 4 9 3 5 3 2 1 3 0 7 4 3 2 7 2 7 1 3 7 4 3 8 6 8 0 7 6 4 2 4 9 6 0 7 2 7 4 2 8 4 4 8 1 9 4 7 2 2 8 6 0 3 2 8 2 1 4 4 8 4 6 0 0 2 3 6 7 6 8 7 0 4 6 2 1 4 2 1 6 1 2 1 4 5 6 5 9 0 4 9 9 1 9 3 3 5 1 2 1 1 0 9 7 6 2 6 8 3 8 6 0 9 9 2 0 3 0 5 6 4 3 8 6 1 1 6 0 9 4 9 2 4 6 8 9 0 2 3 0 1 1 3 3 5 4 0 1 2 8 0 9 8 6 7 4 8 4 6 0 4 9 6 2 5 9 9 7 6 5 4 1 4 5 0 5 8 4 6 6 0 0 6 1 3 7 0 5 7 4 0 7 2] [6. 5. 0. 4. 8. 6. 8. 2. 5. 4. 9. 7. 0. 9. 9. 8. 1. 3. 8. 8. 9. 9. 8. 8. 6. 0. 9. 4. 1. 7. 5. 6. 4. 0. 4. 7. 0. 8. 9. 1. 5. 4. 8. 3. 1. 1. 8. 7. 9. 1. 4. 7. 0. 4. 3. 8. 5. 5. 2. 9. 9. 7. 9. 3. 6. 4. 2. 1. 6. 9. 9. 6. 2. 4. 0. 9. 5. 0. 1. 5. 7. 6. 3. 8. 8. 4. 8. 4. 4. 2. 2. 3. 0. 9. 7. 9. 6. 9. 7. 4. 9. 3. 8. 7. 8. 4. 7. 1. 1. 3. 0. 4. 7. 2. 4. 1. 7. 3. 1. 8. 9. 2. 7. 6. 8. 9. 3. 5. 8. 9. 2. 3. 5. 7. 3. 4. 4. 4. 0. 9. 8. 2. 6. 9. 6. 0. 1. 3. 4. 5. 4. 7. 4. 2. 0. 4. 5. 1. 1. 0. 9. 7. 1. 8. 1. 5. 5. 7. 3. 3. 2. 5. 5. 3. 6. 0. 9. 3. 7. 1. 0. 1. 1. 2. 8. 0. 9. 8. 7. 0. 1. 1. 4. 2. 2. 6. 9. 8. 2. 8. 9. 1. 5. 3. 1. 1. 5. 2. 5. 4. 6. 4. 0. 2. 4. 6. 6. 6. 2. 3. 8. 7. 7. 2. 2. 3. 4. 0. 6. 5. 8. 0. 5. 9. 4. 1. 2. 7. 1. 9. 3. 4. 4. 9. 7. 9. 1. 3. 2. 0. 5. 9. 9. 6. 8. 8. 6. 8. 1. 1. 3. 3. 0. 6. 6. 2. 8. 1. 3. 2. 5. 9. 4. 8. 1. 4. 1. 8. 8. 7. 6. 3. 7. 8. 4. 4. 4. 2. 4. 9. 2. 9. 6. 9. 1. 1. 4. 8. 0. 1. 0. 8. 8. 7. 3. 2. 1. 7. 1. 8. 4. 0. 9. 9. 4. 2. 6. 0. 3. 6. 0. 4. 6. 2. 9. 0. 5. 2. 1. 2. 2. 9. 2. 7. 6. 0. 4. 6. 2. 8. 7. 1. 7. 0. 4. 8. 3. 4. 5. 2. 6. 6. 6. 0. 6. 9. 7. 5. 8. 7. 3. 6. 5. 8. 8. 4. 9. 6. 1. 2. 2. 9. 9. 9. 2. 3. 0. 2. 9. 1. 6. 7. 9. 3. 1. 9. 4. 9. 7. 9. 2. 5. 7. 6. 8. 8. 3. 7. 5. 0. 8. 4. 3. 1. 1. 0. 4. 2. 4. 9. 9. 1. 6. 2. 5. 1. 2. 0. 0. 9. 8. 9. 9. 3. 4. 1. 0. 7. 6. 5. 0. 7. 1. 0. 8. 9. 4. 2. 7. 4. 0. 6. 7. 5. 7. 3. 4. 1. 9. 7. 6. 9. 4. 1. 7. 1. 0. 9. 1. 8. 3. 4. 1. 7. 7. 1. 8. 0. 0. 7. 4. 0. 0. 2. 9. 7. 2. 3. 3. 5. 3. 4. 9. 3. 1. 1. 2. 0. 4. 1. 8. 6. 0. 4. 9. 1. 5. 0. 2. 6. 6. 5. 7. 8. 5. 3. 2. 4. 9. 2. 3. 9. 9. 6. 4. 6. 6. 
0. 5. 5. 1. 6. 4. 3. 0. 8. 2. 6. 0. 1. 1. 3. 7. 9. 7. 4. 9. 5. 1. 9. 3. 9. 8. 9. 2. 9. 0. 2. 7. 1. 5. 7. 2. 2. 6. 4. 4. 7. 3. 3. 4. 7. 2. 2. 2. 2. 0. 0. 5. 2. 1. 3. 7. 7. 4. 8. 1. 2. 5. 7. 5. 9. 8. 5. 8. 2. 5. 2. 2. 8. 1. 0. 2. 5. 5. 6. 1. 0. 7. 0. 3. 7. 3. 2. 4. 0. 5. 9. 6. 2. 5. 9. 8. 7. 6. 2. 8. 4. 4. 2. 4. 5. 5. 7. 7. 0. 8. 2. 1. 0. 3. 0. 4. 5. 0. 7. 9. 9. 2. 4. 5. 4. 6. 2. 9. 2. 2. 6. 4. 0. 4. 7. 8. 9. 2. 0. 6. 1. 9. 7. 8. 4. 4. 9. 1. 7. 1. 1. 0. 2. 2. 2. 3. 1. 9. 6. 0. 3. 6. 7. 3. 2. 9. 6. 3. 2. 0. 1. 8. 5. 3. 5. 5. 1. 1. 5. 8. 4. 1. 4. 9. 5. 6. 7. 4. 6. 4. 4. 8. 1. 4. 0. 2. 6. 2. 0. 7. 4. 9. 4. 8. 9. 1. 9. 1. 4. 0. 6. 3. 7. 5. 1. 0. 6. 5. 0. 0. 6. 2. 2. 1. 5. 0. 4. 4. 8. 1. 8. 1. 3. 5. 0. 9. 3. 6. 8. 2. 9. 6. 5. 6. 8. 9. 5. 2. 1. 9. 5. 2. 7. 4. 3. 5. 5. 3. 3. 8. 5. 4. 8. 1. 5. 8. 5. 5. 8. 0. 3. 6. 1. 0. 7. 9. 0. 7. 7. 4. 6. 0. 4. 7. 6. 3. 2. 5. 1. 9. 7. 4. 3. 3. 5. 3. 2. 1. 3. 0. 7. 4. 3. 2. 7. 2. 7. 1. 8. 7. 4. 3. 5. 6. 8. 3. 7. 6. 4. 2. 4. 4. 6. 0. 7. 2. 7. 6. 2. 5. 4. 4. 8. 2. 9. 4. 3. 2. 2. 8. 6. 0. 3. 2. 8. 2. 1. 4. 4. 9. 4. 6. 5. 0. 2. 0. 6. 7. 2. 8. 5. 3. 5. 6. 2. 7. 6. 2. 1. 6. 1. 2. 1. 4. 5. 6. 5. 9. 0. 4. 9. 4. 1. 9. 3. 8. 5. 1. 8. 1. 1. 2. 8. 5. 6. 2. 6. 8. 3. 2. 6. 0. 4. 4. 2. 0. 3. 0. 5. 6. 6. 3. 8. 7. 1. 1. 6. 0. 9. 4. 9. 5. 4. 2. 5. 9. 0. 2. 3. 0. 1. 1. 3. 3. 5. 9. 0. 7. 2. 8. 6. 7. 8. 4. 7. 4. 8. 4. 6. 0. 4. 4. 6. 2. 4. 4. 9. 7. 6. 5. 4. 1. 9. 4. 0. 8. 8. 4. 1. 6. 0. 0. 6. 6. 3. 4. 0. 3. 7. 4. 0. 9. 8.] Accuracy: 0.784