Logicky Blog

Logickyの開発ブログです

TensorFlow - tf.train.MonitoredTrainingSession

tf.train.MonitoredTrainingSessionを確認します。 訓練をモニターするのに特化したセッションという感じでしょうか?チュートリアルのコードでは下記のような使われ方をしています。

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.train_dir,
    hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
           tf.train.NanTensorHook(loss),
           _LoggerHook()],
    config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)) as mon_sess:
while not mon_sess.should_stop():
    mon_sess.run(train_op)

普通のセッションを使う代わりに使っています。

https://www.tensorflow.org/api_docs/python/train/distributed_execution#MonitoredTrainingSession

tf.train.MonitoredTrainingSession(master='', is_chief=True, checkpoint_dir=None, scaffold=None, hooks=None, chief_only_hooks=None, save_checkpoint_secs=600, save_summaries_steps=100, config=None)

引数的に、勝手に変数の保存をしてくれたりするようです。

For a chief, this utility sets proper session initializer/restorer. It also creates hooks related to checkpoint and summary saving. For workers, this utility sets proper session creator which waits for the chief to inialize/restore.

引数は下記です。

  • master: String the TensorFlow master to use.
  • is_chief: If True, it will take care of initialization and recovery the underlying TensorFlow session. If False, it will wait on a chief to initialize or recover the TensorFlow session.
  • checkpoint_dir: A string. Optional path to a directory where to restore variables.
  • scaffold: A Scaffold used for gathering or building supportive ops. If not specified, a default one is created. It's used to finalize the graph.
  • hooks: Optional list of SessionRunHook objects.
  • chief_only_hooks: list of SessionRunHook objects. Activate these hooks if is_chief==True, ignore otherwise.
  • save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved using a default checkpoint saver. If save_checkpoint_secs is set to None, then the default checkpoint saver isn't used.
  • save_summaries_steps: The frequency, in number of global steps, that the summaries are written to disk using a default summary saver. If save_summaries_steps is set to None, then the default summary saver isn't used.
  • config: an instance of tf.ConfigProto proto used to configure the session. It's the config argument of constructor of tf.Session.

hookというのはコールバック的な感じで、session.runの前後に実行できるクラスらしい。これを紐づけることができて、リストで複数のhookを登録することができる。カスタマイズできるのが、SessionRunHookで、それ以外に用途が決まっているhookが複数事前に提供されているといったような感じのイメージを持った気がする。

class tf.train.SessionRunHook class tf.train.StopAtStepHook class tf.train.NanTensorHook

configは設定で、log_device_placementは、手動でデバイス指定してる場合にTrueにすると、手動設定したデバイスを選んでくれる的なやつっぽい。自分の非力な1体のPC環境ではあまり関係がないと思う。非力な1体のPC環境の場合、hookは自作する_LoggerHook()だけで、configは未設定でもよさそう。

引数でcheckpoint_dirを聞いてるくらいだから、tf.summary.FileWriter('.\logs', sess.graph)みたいのを書かなくてもログが保存されるようになったりしてるのか試してみようと思います。

import tensorflow as tf
from datetime import datetime
import time

logdir = './logs'

a = tf.constant(2)
b = tf.constant(3)
add = tf.add(a, b)

class Hoge(tf.train.SessionRunHook):
    def begin(self):
        self._step = -1

    def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        return tf.train.SessionRunArgs(add)

    def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        result = run_values.results
        if self._step % 10 == 0:
            format_str = 'RESULT: {}, STEP:{}, {:%Y-%m-%d %H:%M:%S}, {:.2}'
            print(format_str.format(result, self._step, datetime.now(), duration))

tf.contrib.framework.get_or_create_global_step()
with tf.train.MonitoredTrainingSession(
        checkpoint_dir=logdir,
        hooks=[Hoge()]) as mon_sess:
    for _ in range(50):
        mon_sess.run(add)

結果

RESULT: 5, STEP:0, 2017-01-24 17:01:32, 0.11
RESULT: 5, STEP:10, 2017-01-24 17:01:33, 0.00098
RESULT: 5, STEP:20, 2017-01-24 17:01:33, 0.0
RESULT: 5, STEP:30, 2017-01-24 17:01:33, 0.0
RESULT: 5, STEP:40, 2017-01-24 17:01:33, 0.00098

そして、logsディレクトリにログファイルが登録されていて、TensorBoardでも見られた。 before_runのreturnで、tf.train.SessionRunArgs(add)とやると、after_runのrun_values.resultsに結果が入ってくる。