1 算法介绍
2 代码实现
2.1 从零实现（代码较复杂）：所有函数都自己编写。
# -*- coding:utf-8 -*-
"""Linear regression trained with minibatch SGD, implemented from scratch on MXNet NDArray."""
import random

import matplotlib as mpl
from matplotlib import pyplot as plt
from mxnet import autograd, nd

mpl.rcParams["font.sans-serif"] = ['Fangsong']
mpl.rcParams['axes.unicode_minus'] = False


def visualization(features, labels):
    """Scatter-plot the second feature column against the labels and show the figure."""
    plt.scatter(features[:, 1].asnumpy(), labels.asnumpy())
    plt.show()


def data_iter(batch_size, features, labels):
    """Yield ``(features, labels)`` minibatches of up to ``batch_size`` examples.

    Examples are visited in a freshly shuffled order on every call. The final
    batch is smaller when the number of examples is not a multiple of
    ``batch_size``.
    """
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)  # read examples in random order
    for i in range(0, num_examples, batch_size):
        j = nd.array(indices[i: min(i + batch_size, num_examples)])
        # take() returns the elements at the given indices
        yield features.take(j), labels.take(j)


def linreg(x, w, b):
    """Linear model: return ``dot(x, w) + b``."""
    return nd.dot(x, w) + b


def squared_loss(y_hat, y):
    """Element-wise half squared error; ``y`` is first reshaped to ``y_hat``'s shape."""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2


def sgd(params, lr, batch_size):
    """Apply one minibatch SGD step, updating each parameter in place.

    Args:
        params: parameters whose ``.grad`` has been filled by ``backward()``.
        lr: learning rate.
        batch_size: divisor for the gradient. Calling ``backward()`` on the
            per-example loss vector accumulates gradients summed over the
            batch, so this turns the sum into a mean.
    """
    # Use a distinct loop variable instead of rebinding the ``params`` list name.
    for param in params:
        param[:] = param - lr * param.grad / batch_size


if __name__ == '__main__':
    num_inputs = 2
    num_examples = 1000
    true_w = [2, -3.4]
    true_b = 4.2
    features = nd.random.normal(scale=1, shape=(num_examples, num_inputs))
    labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
    # visualization(features, labels)  # uncomment to plot the generated data

    batch_size = 10

    w = nd.random.normal(scale=0.01, shape=(num_inputs, 1))
    b = nd.zeros(shape=(1,))
    # Allocate gradient buffers so autograd records gradients for w and b.
    w.attach_grad()
    b.attach_grad()

    lr = 0.03
    num_epoch = 30
    net = linreg
    loss = squared_loss

    for epoch in range(num_epoch):
        for x, y in data_iter(batch_size, features, labels):
            with autograd.record():
                l = loss(net(x, w, b), y)
            l.backward()
            sgd([w, b], lr, batch_size)
        # Evaluate the loss on the whole dataset once per epoch.
        train_l = loss(net(features, w, b), labels)
        # asscalar() yields a Python float; formatting a shape-(1,) array
        # with %f relies on deprecated array-to-scalar conversion.
        print('epoch %d, loss: %lf' % (epoch + 1, train_l.mean().asscalar()))

    print(true_w, true_b)
    print(w, b)
2.2 简洁实现：直接调用现成的包（Gluon 提供的接口）。
# -*- coding:utf-8 -*-
"""Linear regression trained with Gluon's DataLoader, Dense layer, L2Loss and Trainer."""
import random

import matplotlib as mpl
from matplotlib import pyplot as plt
from mxnet import autograd, nd, init
from mxnet import gluon
from mxnet.gluon import data as gdata
from mxnet.gluon import loss as gloss
from mxnet.gluon import nn

mpl.rcParams["font.sans-serif"] = ['Fangsong']
mpl.rcParams['axes.unicode_minus'] = False


def visualization(features, labels):
    """Scatter-plot the second feature column against the labels and show the figure."""
    plt.scatter(features[:, 1].asnumpy(), labels.asnumpy())
    plt.show()


if __name__ == '__main__':
    num_inputs = 2
    num_examples = 1000
    true_w = [2, -3.4]
    true_b = 4.2
    features = nd.random.normal(scale=1, shape=(num_examples, num_inputs))
    labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
    # visualization(features, labels)  # uncomment to plot the generated data

    batch_size = 10
    dataset = gdata.ArrayDataset(features, labels)
    data_iter = gdata.DataLoader(dataset, batch_size, shuffle=True)

    lr = 0.03
    num_epoch = 30

    net = nn.Sequential()
    net.add(nn.Dense(1))
    net.initialize(init.Normal(sigma=0.01))
    loss = gloss.L2Loss()
    # Pass lr instead of a duplicated literal, so changing lr above takes effect.
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})

    for epoch in range(num_epoch):
        for x, y in data_iter:
            with autograd.record():
                l = loss(net(x), y)
            l.backward()
            # step(batch_size) normalizes the accumulated gradient by 1/batch_size.
            trainer.step(batch_size)
        # Evaluate the loss on the whole dataset once per epoch.
        train_l = loss(net(features), labels)
        print('epoch %d, loss: %lf' % (epoch + 1, train_l.mean().asscalar()))

    dense = net[0]
    print(true_w, dense.weight.data())
    print(true_b, dense.bias.data())