Logicky BLOG

Logickyの開発ブログです

  • Javascript
  • Python
  • PHP
  • Go
  • OS・サーバ
  • 機械学習
  • つくったもの
  • 数学
  • アルゴリズム
  • Logicky

TensorFlow - weight decay

機械学習のweight decayは、重みの2乗ノルム(L2ノルム)を損失関数に加えること。これによって重みが大きいと損失関数の値が大きくなるので、重みが大きくなりすぎないようになる。過学習は重みが大きくなることで発生することが多いからこういうことする。L2ノルムは、各次元の値の2乗の和。

サンプル(TensorFlow使ってない)

weight_decay = 0
for idx in range(1, self.hidden_layer_num + 2):
    W = self.params['W' + str(idx)]
    weight_decay += 0.5 * self.weight_decay_lambda * np.sum(W ** 2)
return self.last_layer.forward(y, t) + weight_decay

サンプル(TensorFlow)

def _variable_with_weight_decay(name, shape, stddev, wd):
  """Helper to create an initialized Variable with weight decay.

  Note that the Variable is initialized with a truncated normal distribution.
  A weight decay is added only if one is specified.

  Args:
    name: name of the variable
    shape: list of ints
    stddev: standard deviation of a truncated Gaussian
    wd: add L2Loss weight decay multiplied by this float. If None, weight
        decay is not added for this Variable.

  Returns:
    Variable Tensor
  """
  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
  var = _variable_on_cpu(
      name,
      shape,
      tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
  if wd is not None:
    weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', weight_decay)
  return var

上記の、_variable_on_cpu()は、下記。

def _variable_on_cpu(name, shape, initializer):
  """Helper to create a Variable stored on CPU memory.

  Args:
    name: name of the variable
    shape: list of ints
    initializer: initializer for Variable

  Returns:
    Variable Tensor
  """
  with tf.device('/cpu:0'):
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
  return var

そして、_variable_with_weight_decay()は下記のように使われる。

kernel = _variable_with_weight_decay('weights',
                                     shape=[5, 5, 3, 64],
                                     stddev=5e-2,
                                     wd=0.0)

(tf.nn.l2_loss()

多分2乗した合計かその半分が出てくるんじゃないかと思うので、確かめてみる。

import tensorflow as tf

wd = tf.constant(0.01)
W = tf.constant([1., 2., 3.])
l2 = tf.nn.l2_loss(W)
weight_decay = tf.mul(l2, wd)

with tf.Session() as sess:
    l2, weight_decay = sess.run([l2, weight_decay])
    print(l2)
    print(weight_decay)

結果

7.0
0.07

2乗和の半分っぽい。

TensorFlowのサンプルのように、wdを0.0にしてるってことは、weight decayに含めないってことか。畳み込み層は含まず、全結合層だけwdを0.004にしてる。

TensorFlowで損失関数にL2ノルムのweight decayを足してるところ

def loss(logits, labels):
  """Add L2Loss to all the trainable variables.

  Add summary for "Loss" and "Loss/avg".
  Args:
    logits: Logits from inference().
    labels: Labels from distorted_inputs or inputs(). 1-D tensor
            of shape [batch_size]

  Returns:
    Loss tensor of type float.
  """
  # Calculate the average cross entropy loss across the batch.
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits, name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)

  # The total loss is defined as the cross entropy loss plus all of the weight
  # decay terms (L2 loss).
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

バッチ処理してるので、cross_entropyの平均をとって、weigt decayが入った'losses'というコレクションに入れて、コレクション内の数字を全部足している。

  • Javascript
  • Python
  • PHP
  • Go
  • OS・サーバ
  • 機械学習
  • つくったもの
  • 数学
  • アルゴリズム
  • Logicky