NAS with RL

2017-ICLR-Neural Architecture Search with Reinforcement Learning

  • Google Brain
  • Quoc V . Le etc
  • GitHub: stars
  • Citation:1499

Abstract

用RNN生成模型描述(边长的字符串),用RL(强化学习)训练RNN,来最大化模型在验证集上的准确率。

Motivation

深度学习的成功是因为范式的转变:特征设计(SIFT、HOG)到结构设计(AlexNet、VGGNet)。

2017-ICLR-NAS_with_RL-Neural Architecture Search with Reinforcement Learning-论文阅读-LMLPHP

控制器RNN生成很多网络结构(用变长字符串描述),以p的概率采样出结构A,训练网络A,得到准确率R,计算p的梯度,and scale it by R* to update the controller(RNN).

观察到网络结构和连接可以可以表示为变长的字符串。

变长字符串可以用RNN(控制器)来生成。

训练特定的字符串(子网络),在验证集上,得到验证集准确率。

使用验证集准确率作为奖励,更新控制器RNN。

在下一个迭代中,控制器RNN生成准确率高的结构的概率会更大。就是说控制器RNN会学会不断改进搜索(生成)策略,以生成更好地结构。

Contribution

  • 将卷积网络结构描述为变长的字符串
  • 使用RNN控制器来生成变长的字符串
  • 使用准确率来更新RNN使得生成的结构(字符串)质量越来越高

Method

3.1 GENERATE MODEL DESCRIPTIONS WITH A CONTROLLER RECURRENT NEURAL NETWORK

设我们要预测(生成/搜索)的前向网络是卷积网络,我们可以用控制器RNN来生成每一层的超参数(序列):(卷积核高、宽,stride 高、宽,卷积核数量)五元组

根据经验,层数超过特定值的时候,生成结构的过程将会停止。**层数从少到多?最后都是生成指定层的结构?

这个指定值随着训练过程逐渐增加。

一旦控制器RNN完成一个结构的生成,该结构的网络已经建立并且被训练完毕**。

(子网络训练**)收敛时,记录验证集上的准确率。

根据验证集准确率更新控制器RNN,参数θc。

下一部分继续阐述如何根据梯度策略更新控制器RNN的参数θc

3.2 Training with Reinforce

控制器RNN生成的代表子网络的序列可以写为\(a_{1:T}\).

子网络训练到收敛时,会得到在验证集上的准确率R

我们可以使用R作为训练RNN控制器的奖励

具体地,我们让控制器最大化奖励R的期望,期望可以表示为\(J(θ_c)\):

\(J\left(\theta_{c}\right)=E_{P\left(a_{1: T} ; \theta_{c}\right)}[R]\).

️ **如何计算R的期望?\(P\left(a_{1: T} ; \theta_{c}\right)\),是什么?

由于R是不可微的,所以我们使用梯度方法来迭代更新\(θ_c\).

️ **\(P\left(a_{t} | a_{(t-1): 1} ; \theta_{c}\right)\).是什么?\(\sum_{t=1}^{T}\).又是什么?

上述等式右边根据经验近似为:

️ 怎么近似的?

公式中m是不同结构的数量,T是控制结构的序列的长度(超参的数量)

\(R_k\)是第k个结构的训练精度

以上是梯度的无偏估计,但️方差较大?,我们将其剪去baseline

\(\frac{1}{m} \sum_{k=1}^{m} \sum_{t=1}^{T} \nabla_{\theta_{c}} \log P\left(a_{t} | a_{(t-1): 1} ; \theta_{c}\right)\left(R_{k}-b\right)\)

只要baseline不依赖当前值,就仍然是无偏估计

具体地,baseline的值b为先前结构的指数移动平均值(EMA)

️ 每次训练一个子网络到收敛时才更新控制器RNN的梯度?

训练一个子网络花费几个小时,我们使用分布式训练来加速控制器RNN的学习

我们使用parameter-server的策略进行分布式训练....

3.3 Increase Architecture Complexity Skip Connections and Other Layer Types

在3.1节中,搜索空间只有卷积层,没有skip connection(ResNet),branching layers(GoogLeNet)

这一节中,我们允许控制器RNN提出skip connections 和 branch layers,即扩大搜索空间

为了让控制器RNN预测这些新的连接,我们使用了一种​ ️ 注意力机制(集合选择型注意力?)

在第N层,我们添加N-1个anchor point ️ ,anchor point是基于content 的sigmoids 函数,来指示之前的N-1个层是否需要连接到当前层

每个sigmoid函数是控制器RNN当前隐藏状态 和 之前N-1个anchor points隐藏状态的函数,第 \(i/N\) 层的sigmoid函数可以表示为:

式中 \(h_j\) 表示控制器RNN第 \(j\) 层anchor point的隐藏状态,\(j∈[0, N-1]\)

我们从这些sigmoids中采样,以决定将先前的哪个层作为当前层的输入

在这些连接中,我们一样定义概率分布,强化(学习)的方法依然应用,无需额外修改

如果有多个input layer,那么这些input在depth维度上concatenated

skip connections会导致concatenated失败,比如不同层的output维度不同、一个层没有input或没有output,为了解决这个问题,我们使用了3个技术

一,如果一个层没有input layer,那么把image作为input layer

二,在最后一层,我们将之前所有没有output layer的层的outputs concatenate,作为最后一层的输入/ ️ 隐藏状态?

三,如果需要concatenate的多个input layers的维度不同,用zeros padding小的input使维度统一

在3.1节中,我们不预测learning rate,且假设网络只包含卷积层,限制很严格

加上对learning rate的预测

此外,也可以加上对pooling,LCN(局部对比度归一化),bn层的预测

为了增加更多层类型,我们需要在控制器RNN增加额外的步骤,以及额外的超参数(来表示新的层)

Experiments

搜索空间:卷积结构,包含ReLU、BN、skip connections

同时在800块GPU上训练

过去5个epochs的最高测试集精度作为更新控制器RNN的奖励

从5w个训练集样本中抽5000个作为验证集,剩下45000个作为训练集

定义Optimizer,weight decay,momentum

随着训练过程进行,逐渐增加网络层数

层数从2开始,每1600个子网络层数增加2

共训练了12800个子网络

网格搜索结构的超参数:learning rate, weight decay, batchnorm epsilon,lr进行weight decay的epoch数

2017-ICLR-NAS_with_RL-Neural Architecture Search with Reinforcement Learning-论文阅读-LMLPHP

不预测stride(stride fix to1)和pooling的15层卷积网络,err rate:5.50

该网络的优点:深度最浅,计算量最小

This architecture is shown in Appendix A, Figure 7.

2017-ICLR-NAS_with_RL-Neural Architecture Search with Reinforcement Learning-论文阅读-LMLPHP

观察该结构,1.有很多矩形卷积核(️ 矩形卷积核?)2.越深的层偏爱大卷积核 3.有很多skip connections

该结构只是局部最优,对结构参数(字符串)进行微小扰动的话,会降低网络表现

In the second set of experiments, we ask the controller to predict strides in addition to other hyperparameters.

另一组实验,(stride in [1 2 3])

找到一个20层的结构,err rate:6.01,比第一组实验还差

允许引入2个pooling层(分别在第13和24层),设计39层的网络,err rate:4.47

为了限制搜素空间复杂度,我们搜索13层,每层都是3层的full connected block(如下,网上找的图) 的结构( ️ 一共19层?)

2017-ICLR-NAS_with_RL-Neural Architecture Search with Reinforcement Learning-论文阅读-LMLPHP

改变每层filter num的搜索范围从[24 36 48 64] 改为 [6 12 24 36]

每一层额外添加40+个filters,err rate:3.65%,比DenseNet 的 3.74% 更好

Conclusion

介绍了NAS

使用RNN控制器,灵活地搜索变长的结构搜索空间

搜索到的结构有很强的性能

Appendix

05-11 16:12