NAS 学习笔记(二)- SMASH

  |     |   本文总阅读量:

版权声明:本文原创,转载请留意文尾,如有侵权请留言,谢谢

引言

  SMASH[1] 也是一种 NAS 算法,它可以加速 architecture 的搜索的过程,主要用到的技术就是一个辅助的 HyperNet,它会根据 architecture 动态的生成特定的 weights,而不是 random weights。同时,论文中还提出了一种基于 memory read-writes 的灵活机制,它可以定义各种网络连接模式。

SMASH

  SMASH 算法的伪代码如下图所示:

  SMASH 会首先训练 HyperNet,它会每次随机采样一个 architecture,然后用 HyperNet 生成的 weights 去训练它,接着更新 HyperNet。等到 HyperNet 训练完,SMASH 会采样一些 architecture,用 HyperNet 生成的 weights 去初始化它们,然后 evaluate 它们在验证集上的效果,选择一个最好的 architecture,来正常地训练它的 weights。
  SMASH 有两个重要的问题:

  • 如何采样 architecture,论文就此提出了 Memory-Bank,它存储了 feed-forward 网络的 view,这些 view 允许对复杂的分支拓扑进行采样,并将所采拓扑编码为二进制向量。
  • 如何对于特定的 architecture 采样 weights,论文采用了HyperNet,它可以直接从二进制架构编码映射到权重空间。

Memory-Bank

  什么是 Memory-Bank 呢?相比于将 network 看作是一系列附加到 forward-propagating signal 上的 operation,论文中将 network 看作是一系列的 memory bank(initially tensors filled with zeros),它可以支持读写。每一层所做的 operation 就是从一个 subset of memory 里读出数据,然后修改数据,再将结果写入到另一个 subset of memory 中。下图是几种经典网络的 block 的示意图:

  论文中是拿 CNN 举例,如下图所示,base network 由多个 block 组成,对于给定的 spatial resolution 下每个 block 有一系列的 bank,与大多数CNN架构一样,spatial resolution 不断减半。downsampling 通过一个 \(1 \times 1\) 的卷积和 average pooling 完成,\(1 \times 1\) 的卷积和全连接输出层的权值可以自由学习。

  当 downsampling 一个 architecture 时,在每个 block 中, bank 的数量和 channel 的数量都是是随机采样的。当定义 block 的每层时,我们随机选择 read-write pattern 并且在读取数据时随机选择 operation。当从多个 bank 中读取时,我们沿着channel axis 来 concatenate 读取的张量,当写入 bank 时,我们把当前每个存储体的张量加起来。
  每个 operation 由 \(1 \times 1\) 卷积(用了减少输入通道数)和数量不等的带非线性的卷积组成。如下图所示,我们随机选择 4 个激活的卷积,以及它们的 filter size,dilation factor,number of groups 和 number of output units。\(1 \times 1\) 卷积的输出通道数是 operation 输出通道数的 bottleneck ratio。

  \(1 \times 1\) 卷积的 weights 是由 HyperNet 提供的,而其它卷积的 weights 则是可学习的。

Dynamic HyperNet

  HyperNet 是用来给其它 network 产生 weighs的,相比于 Static HyperNet,Dynamic HyperNet 能基于 architecture \(c\) 编码的张量产生 weights \(W=H(c)\),我们的目标也就是是去学习这个 mapping,对于任何给定的输入都可以得到合理的接近最优的 \(W\),所以我们能够基于验证集上的误差排序每个 architecture。因此,我们采用了一种 \(c\) 的排布策略,以便能够对拓扑结构进行采样,并与标准库中的 toolbox 兼容,并使 \(c\) 的维度尽可能具有可解释性。
  我们的HyperNet是全卷积的,以至于输出张量 \(W\) 的维度随着输入 \(c\) 的维度变化,我们得到标准格式 BCHW 的 4D 张量,批量大小为 1,这样没有输出元素是完全独立的。这允许我们通过增加 \(c\) 的高度或宽度来改变主要网络的深度和宽度。根据这一策略, \(W\) 的每一片空间维度对应于 \(c\) 的一个特定子集。operation 的信息通过 \(W\) 子集嵌入在通道维度相应的 \(c\) 片来描述的。

Conclusion

  SMASH 大大减少了算力成本,它通过随机采样 architecture,在训练集上通过 SGD 学习 HyperNet 的 weights。并使用 HyperNet 生成的 weights 初始化 child network,比每次随机初始化 weights 并从头训练 child network 减少了大量算力成本。
  然而,由于需要设计独立的 HyperNet 来产生 search space 中所有可能 architecture 的 weights,然而,HyperNet 的设计需要精细的专业知识,才能在采样模型的真实性能和生成的权值之间强相关。   

Refer

[1]. Brock, Andrew, et al. "SMASH: one-shot model architecture search through hypernetworks." arXiv preprint arXiv:1708.05344 (2017).

相关内容


坚持原创技术分享,您的支持将鼓励我继续创作,π(3.14)元就够啦!



文章目录
  1. 1. 引言
  2. 2. SMASH
  3. 3. Memory-Bank
  4. 4. Dynamic HyperNet
  5. 5. Conclusion
  6. 6. Refer
  7. 7. 相关内容
您是第 位小伙伴 | 本站总访问量 | 已经写了 670.5k 字啦

载入天数...载入时分秒...