なんもわからんな

Shake Shakeの実装と簡単な解説

前書き

最近画像系コンペに再び参加してみようと思い,既存の分類精度が高いモデルをもう一度調べたいと思ったので,メモ程度に残したいと思います.

Shake Shakeとは

Shake ShakeとはResNetの一種で,中間層でdata augmentationをしており正則化をしています. shake shakeのモデル図は,以下のように2つに分岐しており分岐しています.
foward時はそれぞれの最後でαi ∈ [0, 1] を乗算しbackward時はαiとは異なる βi ∈ [0, 1] を使用します.またtest時は0.5で固定してやるそうです. forward時のこれは,画像に含まれる物体の割合が変化してもロバストに識別ができるように学習ができるようです.
backward時は,勾配にノイズを加えると精度が向上するためであり,αiと違う乱数を用いいることでさらに強い正則化効果を持たすことができるようです. f:id:yakuta55:20180131100207p:plain

また論文にはThe skip connections represent the identity function except during downsampling where a slightly customized structure consisting of 2 concatenated flows is used. Each of the 2 flows has the following components: 1x1 average pooling with step 2 followed by a 1x1 convolution. The input of one of the two flows is shifted by 1 pixel right and 1 pixel down to make the average pooling sample from a different position. The concatenation of the two flows doubles the width. のような記述があり,skipをダウンサンプリングする場合にはskipをの1つは右に1下に1ずらす必要がありそうです.その後に1x1average pooligと2つの畳み込みをした後にもう1つも同様にpoolingと畳み込みをしたものと結合するようです.

実装

実装はTensorflowで書きました.
Residual Blockの部分を抜粋して載せておきます.

import tensorflow as tf

def residual_block(x, a, filter_size, stride):
    def convolution(x, l):
        h = tf.nn.relu(x)
        h = tf.layers.conv2d(h, filter_size, [3,3], [stride, stride], padding="SAME")
        h = tf.nn.relu(tf.layers.batch_normalization(h))
        h = tf.layers.conv2d(h, filter_size, [3,3], padding="SAME")
        h = tf.layers.batch_normalization(h)
        return h  * (a if l else 1-a)
    
    def down_sampling(x):
        x = tf.nn.relu(x)
        h1 = tf.layers.average_pooling2d(x, [1,1], [2,2])
        h1 = tf.layers.conv2d(h1, filter_size/2, [1,1], padding="SAME")
        h2 = tf.pad(x[:, 1:, 1:] ,[[0,0], [0,1], [0,1], [0,0]])
        h2 = tf.layers.average_pooling2d(h2, [1,1], [2,2])
        h2 = tf.layers.conv2d(h2, filter_size/2, [1,1], padding="SAME")
        return tf.concat([h1, h2], axis=-1)

    lbranch = convolution(x, True)
    rbranch = convolution(x, False)
    branch = lbranch + rbranch

    if not x.get_shape().as_list()[-1] == filter_size:
        x = down_sampling(x)
    else:
        x = tf.identity(x, name='x')

    return x + branch 

参考文献

Shake-Shake regularization

[サーベイ論文] 畳み込みニューラルネットワークの研究動向
公式の実装