优化与深度学习
优化与估计
尽管优化方法可以最小化深度学习中的损失函数值,但本质上优化方法达到的目标与深度学习的目标并不相同。
- 优化方法目标:训练集损失函数值
- 深度学习目标:测试集损失函数值(泛化性)
%matplotlib inline import sys sys.path.append('/home/kesci/input') import d2lzh1981 as d2l from mpl_toolkits import mplot3d # 三维画图 import numpy as np
def f(x): return x * np.cos(np.pi * x) def g(x): return f(x) + 0.2 * np.cos(5 * np.pi * x) d2l.set_figsize((5, 3)) x = np.arange(0.5, 1.5, 0.01) fig_f, = d2l.plt.plot(x, f(x),label="train error") fig_g, = d2l.plt.plot(x, g(x),'--', c='purple', label="test error") fig_f.axes.annotate('empirical risk', (1.0, -1.2), (0.5, -1.1),arrowprops=dict(arrowstyle='->')) fig_g.axes.annotate('expected risk', (1.1, -1.05), (0.95, -0.5),arrowprops=dict(arrowstyle='->')) d2l.plt.xlabel('x') d2l.plt.ylabel('risk') d2l.plt.legend(loc="upper right")
<matplotlib.legend.Legend at 0x7fc092436080>
优化在深度学习中的挑战
- 局部最小值
- 鞍点
- 梯度消失
局部最小值
def f(x): return x * np.cos(np.pi * x) d2l.set_figsize((4.5, 2.5)) x = np.arange(-1.0, 2.0, 0.1) fig, = d2l.plt.plot(x, f(x)) fig.axes.annotate('local minimum', xy=(-0.3, -0.25), xytext=(-0.77, -1.0), arrowprops=dict(arrowstyle='->')) fig.axes.annotate('global minimum', xy=(1.1, -0.95), xytext=(0.6, 0.8), arrowprops=dict(arrowstyle='->')) d2l.plt.xlabel('x') d2l.plt.ylabel('f(x)');
鞍点
x = np.arange(-2.0, 2.0, 0.1) fig, = d2l.plt.plot(x, x**3) fig.axes.annotate('saddle point', xy=(0, -0.2), xytext=(-0.52, -5.0), arrowprops=dict(arrowstyle='->')) d2l.plt.xlabel('x') d2l.plt.ylabel('f(x)');
e.g.
x, y = np.mgrid[-1: 1: 31j, -1: 1: 31j] z = x**2 - y**2 d2l.set_figsize((6, 4)) ax = d2l.plt.figure().add_subplot(111, projection='3d') ax.plot_wireframe(x, y, z, **{'rstride': 2, 'cstride': 2}) ax.plot([0], [0], [0], 'ro', markersize=10) ticks = [-1, 0, 1] d2l.plt.xticks(ticks) d2l.plt.yticks(ticks) ax.set_zticks(ticks) d2l.plt.xlabel('x') d2l.plt.ylabel('y');
梯度消失
x = np.arange(-2.0, 5.0, 0.01) fig, = d2l.plt.plot(x, np.tanh(x)) d2l.plt.xlabel('x') d2l.plt.ylabel('f(x)') fig.axes.annotate('vanishing gradient', (4, 1), (2, 0.0) ,arrowprops=dict(arrowstyle='->'))
Text(2, 0.0, 'vanishing gradient')