NAS with RL
2017-ICLR-Neural Architecture Search with Reinforcement Learning
- Google Brain
- Quoc V . Le etc
- GitHub: stars
- Citation:1499
Abstract
用RNN生成模型描述(边长的字符串),用RL(强化学习)训练RNN,来最大化模型在验证集上的准确率。
Motivation
深度学习的成功是因为范式的转变:特征设计(SIFT、HOG)到结构设计(AlexNet、VGGNet)。
控制器RNN生成很多网络结构(用变长字符串描述),以p的概率采样出结构A,训练网络A,得到准确率R,计算p的梯度,and scale it by R* to update the controller(RNN).
观察到网络结构和连接可以可以表示为变长的字符串。
变长字符串可以用RNN(控制器)来生成。
训练特定的字符串(子网络),在验证集上,得到验证集准确率。
使用验证集准确率作为奖励,更新控制器RNN。
在下一个迭代中,控制器RNN生成准确率高的结构的概率会更大。就是说控制器RNN会学会不断改进搜索(生成)策略,以生成更好地结构。
Contribution
- 将卷积网络结构描述为变长的字符串
- 使用RNN控制器来生成变长的字符串
- 使用准确率来更新RNN使得生成的结构(字符串)质量越来越高
Method
3.1 GENERATE MODEL DESCRIPTIONS WITH A CONTROLLER RECURRENT NEURAL NETWORK
设我们要预测(生成/搜索)的前向网络是卷积网络,我们可以用控制器RNN来生成每一层的超参数(序列):(卷积核高、宽,stride 高、宽,卷积核数量)五元组
根据经验,层数超过特定值的时候,生成结构的过程将会停止。**层数从少到多?最后都是生成指定层的结构?
这个指定值随着训练过程逐渐增加。
一旦控制器RNN完成一个结构的生成,该结构的网络已经建立并且被训练完毕**。
(子网络训练**)收敛时,记录验证集上的准确率。
根据验证集准确率更新控制器RNN,参数θc。
下一部分继续阐述如何根据梯度策略更新控制器RNN的参数θc
3.2 Training with Reinforce
控制器RNN生成的代表子网络的序列可以写为\(a_{1:T}\).
子网络训练到收敛时,会得到在验证集上的准确率R
我们可以使用R作为训练RNN控制器的奖励
具体地,我们让控制器最大化奖励R的期望,期望可以表示为\(J(θ_c)\):
\(J\left(\theta_{c}\right)=E_{P\left(a_{1: T} ; \theta_{c}\right)}[R]\).
️ **如何计算R的期望?\(P\left(a_{1: T} ; \theta_{c}\right)\),是什么?
由于R是不可微的,所以我们使用梯度方法来迭代更新\(θ_c\).
️ **\(P\left(a_{t} | a_{(t-1): 1} ; \theta_{c}\right)\).是什么?\(\sum_{t=1}^{T}\).又是什么?
上述等式右边根据经验近似为:
️ 怎么近似的?
公式中m是不同结构的数量,T是控制结构的序列的长度(超参的数量)
\(R_k\)是第k个结构的训练精度
以上是梯度的无偏估计,但️方差较大?,我们将其剪去baseline
\(\frac{1}{m} \sum_{k=1}^{m} \sum_{t=1}^{T} \nabla_{\theta_{c}} \log P\left(a_{t} | a_{(t-1): 1} ; \theta_{c}\right)\left(R_{k}-b\right)\)
只要baseline不依赖当前值,就仍然是无偏估计
具体地,baseline的值b为先前结构的指数移动平均值(EMA)
️ 每次训练一个子网络到收敛时才更新控制器RNN的梯度?
训练一个子网络花费几个小时,我们使用分布式训练来加速控制器RNN的学习
我们使用parameter-server的策略进行分布式训练....
3.3 Increase Architecture Complexity Skip Connections and Other Layer Types
在3.1节中,搜索空间只有卷积层,没有skip connection(ResNet),branching layers(GoogLeNet)
这一节中,我们允许控制器RNN提出skip connections 和 branch layers,即扩大搜索空间
为了让控制器RNN预测这些新的连接,我们使用了一种 ️ 注意力机制(集合选择型注意力?)
在第N层,我们添加N-1个anchor point ️ ,anchor point是基于content 的sigmoids 函数,来指示之前的N-1个层是否需要连接到当前层
每个sigmoid函数是控制器RNN当前隐藏状态 和 之前N-1个anchor points隐藏状态的函数,第 \(i/N\) 层的sigmoid函数可以表示为:
式中 \(h_j\) 表示控制器RNN第 \(j\) 层anchor point的隐藏状态,\(j∈[0, N-1]\)
我们从这些sigmoids中采样,以决定将先前的哪个层作为当前层的输入
在这些连接中,我们一样定义概率分布,强化(学习)的方法依然应用,无需额外修改
如果有多个input layer,那么这些input在depth维度上concatenated
skip connections会导致concatenated失败,比如不同层的output维度不同、一个层没有input或没有output,为了解决这个问题,我们使用了3个技术
一,如果一个层没有input layer,那么把image作为input layer
二,在最后一层,我们将之前所有没有output layer的层的outputs concatenate,作为最后一层的输入/ ️ 隐藏状态?
三,如果需要concatenate的多个input layers的维度不同,用zeros padding小的input使维度统一
在3.1节中,我们不预测learning rate,且假设网络只包含卷积层,限制很严格
加上对learning rate的预测
此外,也可以加上对pooling,LCN(局部对比度归一化),bn层的预测
为了增加更多层类型,我们需要在控制器RNN增加额外的步骤,以及额外的超参数(来表示新的层)
Experiments
搜索空间:卷积结构,包含ReLU、BN、skip connections
同时在800块GPU上训练
过去5个epochs的最高测试集精度作为更新控制器RNN的奖励
从5w个训练集样本中抽5000个作为验证集,剩下45000个作为训练集
定义Optimizer,weight decay,momentum
随着训练过程进行,逐渐增加网络层数
层数从2开始,每1600个子网络层数增加2
共训练了12800个子网络
网格搜索结构的超参数:learning rate, weight decay, batchnorm epsilon,lr进行weight decay的epoch数
不预测stride(stride fix to1)和pooling的15层卷积网络,err rate:5.50
该网络的优点:深度最浅,计算量最小
This architecture is shown in Appendix A, Figure 7.
观察该结构,1.有很多矩形卷积核(️ 矩形卷积核?)2.越深的层偏爱大卷积核 3.有很多skip connections
该结构只是局部最优,对结构参数(字符串)进行微小扰动的话,会降低网络表现
In the second set of experiments, we ask the controller to predict strides in addition to other hyperparameters.
另一组实验,(stride in [1 2 3])
找到一个20层的结构,err rate:6.01,比第一组实验还差
允许引入2个pooling层(分别在第13和24层),设计39层的网络,err rate:4.47
为了限制搜素空间复杂度,我们搜索13层,每层都是3层的full connected block(如下,网上找的图) 的结构( ️ 一共19层?)
改变每层filter num的搜索范围从[24 36 48 64] 改为 [6 12 24 36]
每一层额外添加40+个filters,err rate:3.65%,比DenseNet 的 3.74% 更好
Conclusion
介绍了NAS
使用RNN控制器,灵活地搜索变长的结构搜索空间
搜索到的结构有很强的性能