機械学習のweight decayは、重みの2乗ノルム(L2ノルム)を損失関数に加えること。これによって重みが大きいと損失関数の値が大きくなるので、重みが大きくなりすぎないようになる。過学習は重みが大きくなることで発生することが多いからこういうことする。L2ノルムは、各次元の値の2乗の和。
サンプル(TensorFlow使ってない)
weight_decay = 0 for idx in range(1, self.hidden_layer_num + 2): W = self.params['W' + str(idx)] weight_decay += 0.5 * self.weight_decay_lambda * np.sum(W ** 2) return self.last_layer.forward(y, t) + weight_decay
サンプル(TensorFlow)
def _variable_with_weight_decay(name, shape, stddev, wd): """Helper to create an initialized Variable with weight decay. Note that the Variable is initialized with a truncated normal distribution. A weight decay is added only if one is specified. Args: name: name of the variable shape: list of ints stddev: standard deviation of a truncated Gaussian wd: add L2Loss weight decay multiplied by this float. If None, weight decay is not added for this Variable. Returns: Variable Tensor """ dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 var = _variable_on_cpu( name, shape, tf.truncated_normal_initializer(stddev=stddev, dtype=dtype)) if wd is not None: weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss') tf.add_to_collection('losses', weight_decay) return var
上記の、_variable_on_cpu()は、下記。
def _variable_on_cpu(name, shape, initializer): """Helper to create a Variable stored on CPU memory. Args: name: name of the variable shape: list of ints initializer: initializer for Variable Returns: Variable Tensor """ with tf.device('/cpu:0'): dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype) return var
そして、_variable_with_weight_decay()は下記のように使われる。
kernel = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64], stddev=5e-2, wd=0.0)
(tf.nn.l2_loss()
多分2乗した合計かその半分が出てくるんじゃないかと思うので、確かめてみる。
import tensorflow as tf wd = tf.constant(0.01) W = tf.constant([1., 2., 3.]) l2 = tf.nn.l2_loss(W) weight_decay = tf.mul(l2, wd) with tf.Session() as sess: l2, weight_decay = sess.run([l2, weight_decay]) print(l2) print(weight_decay)
結果
7.0 0.07
2乗和の半分っぽい。
TensorFlowのサンプルのように、wdを0.0にしてるってことは、weight decayに含めないってことか。畳み込み層は含まず、全結合層だけwdを0.004にしてる。
TensorFlowで損失関数にL2ノルムのweight decayを足してるところ
def loss(logits, labels): """Add L2Loss to all the trainable variables. Add summary for "Loss" and "Loss/avg". Args: logits: Logits from inference(). labels: Labels from distorted_inputs or inputs(). 1-D tensor of shape [batch_size] Returns: Loss tensor of type float. """ # Calculate the average cross entropy loss across the batch. labels = tf.cast(labels, tf.int64) cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels, logits=logits, name='cross_entropy_per_example') cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy') tf.add_to_collection('losses', cross_entropy_mean) # The total loss is defined as the cross entropy loss plus all of the weight # decay terms (L2 loss). return tf.add_n(tf.get_collection('losses'), name='total_loss')
バッチ処理してるので、cross_entropyの平均をとって、weigt decayが入った'losses'というコレクションに入れて、コレクション内の数字を全部足している。