神经网络架构搜索——可微分搜索(SGAS)

简介: KAUST&Intel发表在CVPR 2020上的NAS工作,针对现有DARTS框架**在搜索阶段具有高验证集准确率的架构可能在评估阶段表现不好**的问题,提出了分解神经网络架构搜索过程为**一系列子问题**,SGAS使用**贪婪策略选择并剪枝候选操作**的技术,在搜索CNN和GCN网络架构均达到了SOTA。- Paper: SGAS: Sequential Greedy Architecture Search- Code: https://github.com/lightaime/sgas

动机

NAS技术都有一个通病:在搜索过程中验证精度较高,但是在实际测试精度却没有那么高。传统的基于梯度搜索的DARTS技术,是根据block构建更大的超网,由于搜索的过程中验证不充分,最终eval和test精度会出现鸿沟。从下图的Kendall系数来看,DARTS搜出的网络精度排名和实际训练完成的精度排名偏差还是比较大。

"Accuracy GAP"

方法

整体思路

本文使用与DARTS相同的搜索空间,SGAS搜索过程简单易懂,如下图所示。类似DARTS搜索过程为每条边指定参数α,超网训练时通过文中判定规则逐渐确定每条边的具体操作,搜索结束后即可得到最终模型。

SGAS架构示意图

算法伪代码

为了保证在贪心搜索的过程中能尽量保证搜索的全局最优性,进而引入了三个指标两个评估准则

三个指标

边的重要性

非零操作参数对应的softmax值求和,作为边的重要性衡量指标。

$$ S_{E I}^{(i, j)}=\sum_{o \in \mathcal{O}, o \neq z e r o} \frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)} $$

alphas = []
for i in range(4):
    for n in range(2 + i):
        alphas.append(Variable(1e-3 * torch.randn(8)))
# alphas经过训练后
mat = F.softmax(torch.stack(alphas, dim=0), dim=-1).detach() # mat为14*8维度的二维列表,softmax归一化。 
EI = torch.sum(mat[:, 1:], dim=-1) # EI为14个数的一维列表,去掉none后的7个ops对应alpha值相加
选择的准确性

计算操作分布的标准化熵,熵越小确定性越高;熵越高确定性越小。

$$ \begin{array}{c} p_{o}^{(i, j)}=\frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{S_{E I}^{(i, j)} \sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)}, o \in \mathcal{O}, o \neq z e r o \\ S_{S C}^{(i, j)}=1-\frac{-\sum_{o \in \mathcal{O}, o \neq z e r o} p_{o}^{(i, j)} \log \left(p_{o}^{(i, j)}\right)}{\log (|\mathcal{O}|-1)} \end{array} $$

import torch.distributions.categorical as cate
probs = mat[:, 1:] / EI[:, None]
entropy = cate.Categorical(probs=probs).entropy() / math.log(probs.size()[1])
SC = 1-entropy
选择的稳定性

将历史信息纳入操作分布评估,使用直方图交叉核计算平均选择稳定性。直方图交叉核的原理详见(https://blog.csdn.net/hong__fang/article/details/50550656)。

$$ S_{S S}^{(i, j)}=\frac{1}{K} \sum_{t=T-K}^{T-1} \sum_{o_{t} \in \mathcal{O}, o_{t} \neq z e r o} \min \left(p_{o_{t}}^{(i, j)}, p_{o_{T}}^{(i, j)}\right) $$

def histogram_intersection(a, b):
  c = np.minimum(a.cpu().numpy(),b.cpu().numpy())
  c = torch.from_numpy(c).cuda()
  sums = c.sum(dim=1)
  return sums

def histogram_average(history, probs):
  histogram_inter = torch.zeros(probs.shape[0], dtype=torch.float).cuda()
  if not history:
    return histogram_inter
  for hist in history:
    histogram_inter += utils.histogram_intersection(hist, probs)
  histogram_inter /= len(history)
  return histogram_inter

probs_history = []

probs_history.append(probs)
if (len(probs_history) > args.history_size):
  probs_history.pop(0)
  
histogram_inter = histogram_average(probs_history, probs)

SS = histogram_inter

两种评估准则

评估准则1:

选择具有高边缘重要性和高选择确定性的操作

$$ S_{1}^{(i, j)}=\text { normalize }\left(S_{E I}^{(i, j)}\right) * \text { normalize }\left(S_{S C}^{(i, j)}\right) $$

def normalize(v):
  min_v = torch.min(v)
  range_v = torch.max(v) - min_v
  if range_v > 0:
    normalized_v = (v - min_v) / range_v
  else:
    normalized_v = torch.zeros(v.size()).cuda()

  return normalized_v

score = utils.normalize(EI) * utils.normalize(SC)
评估准则2:

在评估准则1的基础上,加入考虑选择稳定性

$$ S_{2}^{(i, j)}=S_{1}^{(i, j)} * \text { normalize }\left(S_{S S}^{(i, j)}\right) $$

score = utils.normalize(EI) * utils.normalize(SC) * utils.normalize(SS)

实验结果

CIFAR-10(CNN)

CIFAR-10(CNN)

ImageNet(CNN)

ImageNet(CNN)

ModelNet40(GCN)

ModelNet40(GCN)

PPI(GCN)

PPI(GCN)

参考

[1] Li, Guohao et al. ,SGAS: Sequential Greedy Architecture Search

[2] https://zhuanlan.zhihu.com/p/134294068

[3] 直方图交叉核 https://blog.csdn.net/hong__fang/article/details/50550656


![更多内容关注微信公众号【AI异构】]

目录
相关文章
|
6天前
|
安全 网络安全
现代化企业网络安全架构设计与实践
随着企业信息化程度的提升,网络安全问题日益凸显。本文从企业网络安全架构设计与实践的角度出发,探讨了现代化企业网络安全的重要性、设计原则和实施方法,并结合具体案例进行分析,为企业构建健壮的网络安全体系提供了参考和指导。
|
6天前
|
消息中间件 存储 缓存
Kafka【基础知识 01】消息队列介绍+Kafka架构及核心概念(图片来源于网络)
【2月更文挑战第20天】Kafka【基础知识 01】消息队列介绍+Kafka架构及核心概念(图片来源于网络)
129 2
|
6天前
|
Cloud Native Linux 网络虚拟化
深入理解Linux veth虚拟网络设备:原理、应用与在容器化架构中的重要性
在Linux网络虚拟化领域,虚拟以太网设备(veth)扮演着至关重要的角色🌐。veth是一种特殊类型的网络设备,它在Linux内核中以成对的形式存在,允许两个网络命名空间之间的通信🔗。这篇文章将从多个维度深入分析veth的概念、作用、重要性,以及在容器和云原生环境中的应用📚。
深入理解Linux veth虚拟网络设备:原理、应用与在容器化架构中的重要性
|
6天前
|
运维 监控 安全
|
4天前
|
安全 网络协议 网络安全
网络安全笔记整理,你花了多久弄明白架构设计
网络安全笔记整理,你花了多久弄明白架构设计
|
6天前
|
机器学习/深度学习 安全 网络安全
数字堡垒的构筑者:网络安全与信息安全的深层剖析构建高效微服务架构:后端开发的新趋势
【4月更文挑战第30天】在信息技术高速发展的今天,构建坚不可摧的数字堡垒已成为个人、企业乃至国家安全的重要组成部分。本文深入探讨网络安全漏洞的本质、加密技术的进展以及提升安全意识的必要性,旨在为读者提供全面的网络安全与信息安全知识框架。通过对网络攻防技术的解析和案例研究,我们揭示了防御策略的关键点,并强调了持续教育在塑造安全文化中的作用。
|
6天前
|
人工智能 安全 大数据
SDN(软件定义网络)——重塑网络架构的新视角
SDN(软件定义网络)是网络架构革新的关键,通过分离控制与数据平面,实现网络的灵活、高效管理。未来,SDN将更广泛应用于各行业,与云计算、大数据、AI融合,推动数字化转型。开放与标准化的趋势将促进SDN生态发展,提供以业务需求为导向、智能化自动化管理及增强网络安全的新视角。SDN将在更多领域扮演重要角色,支持网络技术的创新与进步。
|
6天前
|
机器学习/深度学习 自然语言处理 网络架构
经典神经网络架构参考 v1.0(4)
经典神经网络架构参考 v1.0
24 0
|
6天前
|
网络架构
经典神经网络架构参考 v1.0(3)
经典神经网络架构参考 v1.0
19 0