从零开始学Pytorch(十二)之凸优化

简介: 从零开始学Pytorch(十二)之凸优化

尽管优化方法可以最小化深度学习中的损失函数值,但本质上优化方法达到的目标与深度学习的目标并不相同。

  • 优化方法目标:训练集损失函数值
  • 深度学习目标:测试集损失函数值(泛化性)
%matplotlib inline
import sys
sys.path.append('/home/input')
import d2lzh1981 as d2l
from mpl_toolkits import mplot3d # 三维画图
import numpy as np
def f(x): return x * np.cos(np.pi * x)
def g(x): return f(x) + 0.2 * np.cos(5 * np.pi * x)
d2l.set_figsize((5, 3))
x = np.arange(0.5, 1.5, 0.01)
fig_f, = d2l.plt.plot(x, f(x),label="train error")
fig_g, = d2l.plt.plot(x, g(x),'--', c='purple', label="test error")
fig_f.axes.annotate('empirical risk', (1.0, -1.2), (0.5, -1.1),arrowprops=dict(arrowstyle='->'))
fig_g.axes.annotate('expected risk', (1.1, -1.05), (0.95, -0.5),arrowprops=dict(arrowstyle='->'))
d2l.plt.xlabel('x')
d2l.plt.ylabel('risk')
d2l.plt.legend(loc="upper right")

3f4acdfa7101fb2ee8162d40f5fb875a.png

优化在深度学习中的挑战


  1. 局部最小值
  2. 鞍点
  3. 梯度消失

局部最小值


image.png

def f(x):
    return x * np.cos(np.pi * x)
d2l.set_figsize((4.5, 2.5))
x = np.arange(-1.0, 2.0, 0.1)
fig,  = d2l.plt.plot(x, f(x))
fig.axes.annotate('local minimum', xy=(-0.3, -0.25), xytext=(-0.77, -1.0),
                  arrowprops=dict(arrowstyle='->'))
fig.axes.annotate('global minimum', xy=(1.1, -0.95), xytext=(0.6, 0.8),
                  arrowprops=dict(arrowstyle='->'))
d2l.plt.xlabel('x')
d2l.plt.ylabel('f(x)');

cbb2d4a3e73c443861977e095ef1e154.png

鞍点


x = np.arange(-2.0, 2.0, 0.1)
fig, = d2l.plt.plot(x, x**3)
fig.axes.annotate('saddle point', xy=(0, -0.2), xytext=(-0.52, -5.0),
                  arrowprops=dict(arrowstyle='->'))
d2l.plt.xlabel('x')
d2l.plt.ylabel('f(x)');

4a4f4e992ec63293f12efb2afae13995.png

00e504d1aa9834d9bb8ec787e23eba74.png

x, y = np.mgrid[-1: 1: 31j, -1: 1: 31j]
z = x**2 - y**2
d2l.set_figsize((6, 4))
ax = d2l.plt.figure().add_subplot(111, projection='3d')
ax.plot_wireframe(x, y, z, **{'rstride': 2, 'cstride': 2})
ax.plot([0], [0], [0], 'ro', markersize=10)
ticks = [-1,  0, 1]
d2l.plt.xticks(ticks)
d2l.plt.yticks(ticks)
ax.set_zticks(ticks)
d2l.plt.xlabel('x')
d2l.plt.ylabel('y');

85b4abc3dd806acc04e477ea6a74f9eb.png

梯度消失


x = np.arange(-2.0, 5.0, 0.01)
fig, = d2l.plt.plot(x, np.tanh(x))
d2l.plt.xlabel('x')
d2l.plt.ylabel('f(x)')
fig.axes.annotate('vanishing gradient', (4, 1), (2, 0.0) ,arrowprops=dict(arrowstyle='->'))

e0b14c119de087bd134f824a6a770ef6.png

凸性 (Convexity)


函数


6a15b69c2afad1dcf2eabd208b4844ee.png

def f(x):
    return 0.5 * x**2  # Convex
def g(x):
    return np.cos(np.pi * x)  # Nonconvex
def h(x):
    return np.exp(0.5 * x)  # Convex
x, segment = np.arange(-2, 2, 0.01), np.array([-1.5, 1])
d2l.use_svg_display()
_, axes = d2l.plt.subplots(1, 3, figsize=(9, 3))
for ax, func in zip(axes, [f, g, h]):
    ax.plot(x, func(x))
    ax.plot(segment, func(segment),'--', color="purple")
    # d2l.plt.plot([x, segment], [func(x), func(segment)], axes=ax)

486c90f57c62f5d8f79b426e8414277f.png

Jensen 不等式


fa3eda264c687b654d7bb82397a02d14.png

性质

  1. 无局部极小值
  2. 与凸集的关系
  3. 二阶条件

无局部最小值


image.png

x, y = np.meshgrid(np.linspace(-1, 1, 101), np.linspace(-1, 1, 101),
                   indexing='ij')
z = x**2 + 0.5 * np.cos(2 * np.pi * y)
# Plot the 3D surface
d2l.set_figsize((6, 4))
ax = d2l.plt.figure().add_subplot(111, projection='3d')
ax.plot_wireframe(x, y, z, **{'rstride': 10, 'cstride': 10})
ax.contour(x, y, z, offset=-1)
ax.set_zlim(-1, 1.5)
# Adjust labels
for func in [d2l.plt.xticks, d2l.plt.yticks, ax.set_zticks]:
    func([-1, 0, 1])

829a95103d3634cae3f2b926277d03a4.jpg

凸函数与二阶导数


image.png

image.png

def f(x):
    return 0.5 * x**2
x = np.arange(-2, 2, 0.01)
axb, ab = np.array([-1.5, -0.5, 1]), np.array([-1.5, 1])
d2l.set_figsize((3.5, 2.5))
fig_x, = d2l.plt.plot(x, f(x))
fig_axb, = d2l.plt.plot(axb, f(axb), '-.',color="purple")
fig_ab, = d2l.plt.plot(ab, f(ab),'g-.')
fig_x.axes.annotate('a', (-1.5, f(-1.5)), (-1.5, 1.5),arrowprops=dict(arrowstyle='->'))
fig_x.axes.annotate('b', (1, f(1)), (1, 1.5),arrowprops=dict(arrowstyle='->'))
fig_x.axes.annotate('x', (-0.5, f(-0.5)), (-1.5, f(-0.5)),arrowprops=dict(arrowstyle='->'))

f90584e3f10aa94fa1bc409e31a7c44f.png

限制条件


92efb4637f18256fceb579c44ec81bab.png

参考文献


[1]《动手深度学习》李沐

[2]伯禹教育课程

相关文章
|
6月前
|
机器学习/深度学习 算法 Python
深入浅出Python机器学习:从零开始的SVM教程/厾罗
深入浅出Python机器学习:从零开始的SVM教程/厾罗
|
6月前
|
机器学习/深度学习 算法 PyTorch
从零开始学习线性回归:理论、实践与PyTorch实现
从零开始学习线性回归:理论、实践与PyTorch实现
从零开始学习线性回归:理论、实践与PyTorch实现
|
编解码 固态存储 算法
【项目实践】从零开始学习SSD目标检测算法训练自己的数据集(附注释项目代码)(二)
【项目实践】从零开始学习SSD目标检测算法训练自己的数据集(附注释项目代码)(二)
312 0
|
机器学习/深度学习 固态存储 算法
【项目实践】从零开始学习SSD目标检测算法训练自己的数据集(附注释项目代码)(一)
【项目实践】从零开始学习SSD目标检测算法训练自己的数据集(附注释项目代码)(一)
521 0
|
机器学习/深度学习 算法 决策智能
【项目实践】从零开始学习Deep SORT+YOLO V3进行多目标跟踪(附注释项目代码)(一)
【项目实践】从零开始学习Deep SORT+YOLO V3进行多目标跟踪(附注释项目代码)(一)
271 0
|
机器学习/深度学习 算法 PyTorch
从零开始学Pytorch(十四)之优化算法进阶(二)
从零开始学Pytorch(十四)之优化算法进阶
从零开始学Pytorch(十四)之优化算法进阶(二)
|
算法 PyTorch 算法框架/工具
从零开始学Pytorch(十四)之优化算法进阶(一)
从零开始学Pytorch(十四)之优化算法进阶
从零开始学Pytorch(十四)之优化算法进阶(一)
|
算法 PyTorch 算法框架/工具
从零开始学Pytorch(十七)之目标检测基础(一)
从零开始学Pytorch(十七)之目标检测基础
从零开始学Pytorch(十七)之目标检测基础(一)
|
机器学习/深度学习 PyTorch 算法框架/工具
从零开始学Pytorch(十七)之目标检测基础(二)
从零开始学Pytorch(十七)之目标检测基础
从零开始学Pytorch(十七)之目标检测基础(二)
|
机器学习/深度学习 PyTorch 算法框架/工具
从零开始学Pytorch(十一)之ModernRNN
从零开始学Pytorch(十一)之ModernRNN
从零开始学Pytorch(十一)之ModernRNN