TensorFlow - tf.train.MonitoredTrainingSession

Let's take a look at tf.train.MonitoredTrainingSession. It appears to be a session specialized for monitoring training. The tutorial code uses it as follows.

# Excerpt; other arguments are omitted.
with tf.train.MonitoredTrainingSession(
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train_op)



tf.train.MonitoredTrainingSession(master='', is_chief=True, checkpoint_dir=None, scaffold=None, hooks=None, chief_only_hooks=None, save_checkpoint_secs=600, save_summaries_steps=100, config=None)


For a chief, this utility sets proper session initializer/restorer. It also creates hooks related to checkpoint and summary saving. For workers, this utility sets proper session creator which waits for the chief to initialize/restore.


  • master: String. The TensorFlow master to use.
  • is_chief: If True, it will take care of initialization and recovery of the underlying TensorFlow session. If False, it will wait on a chief to initialize or recover the TensorFlow session.
  • checkpoint_dir: A string. Optional path to a directory where to restore variables.
  • scaffold: A Scaffold used for gathering or building supportive ops. If not specified, a default one is created. It's used to finalize the graph.
  • hooks: Optional list of SessionRunHook objects.
  • chief_only_hooks: list of SessionRunHook objects. Activate these hooks if is_chief==True, ignore otherwise.
  • save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved using a default checkpoint saver. If save_checkpoint_secs is set to None, then the default checkpoint saver isn't used.
  • save_summaries_steps: The frequency, in number of global steps, that the summaries are written to disk using a default summary saver. If save_summaries_steps is set to None, then the default summary saver isn't used.
  • config: An instance of tf.ConfigProto used to configure the session. It's the config argument of the constructor of tf.Session.
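As a rough reading of the two frequency parameters, the sketch below models an "every N global steps" trigger in plain Python. It is not TensorFlow's code: `ToyStepTrigger`, `steps_saved`, and their methods are hypothetical names, and the real default savers may differ in detail (for example in when the first save happens).

```python
class ToyStepTrigger:
    """Hypothetical stand-in for an 'every N global steps' saver trigger."""

    def __init__(self, every_n_steps):
        # None mirrors the documented meaning: the default saver isn't used.
        self.every_n_steps = every_n_steps
        self.last_step = None

    def should_trigger(self, global_step):
        if self.every_n_steps is None:
            return False
        if self.last_step is None:
            return True  # assumption: fire once at the first step seen
        return global_step - self.last_step >= self.every_n_steps

    def mark(self, global_step):
        # Remember the step at which the last save happened.
        self.last_step = global_step


def steps_saved(every_n_steps, total_steps):
    """Return the steps at which the toy trigger would have saved."""
    trigger = ToyStepTrigger(every_n_steps)
    saved = []
    for step in range(total_steps):
        if trigger.should_trigger(step):
            saved.append(step)
            trigger.mark(step)
    return saved


print(steps_saved(100, 250))   # [0, 100, 200]
print(steps_saved(None, 250))  # []
```

With the default save_summaries_steps=100, this toy model would write at steps 0, 100, and 200 of a 250-step run, and never when the value is None.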


  • class tf.train.SessionRunHook
  • class tf.train.StopAtStepHook
  • class tf.train.NanTensorHook


Since it takes checkpoint_dir as an argument, I'd like to try whether logs get saved even without writing something like tf.summary.FileWriter('.\logs', sess.graph).

import tensorflow as tf
from datetime import datetime
import time

logdir = './logs'

a = tf.constant(2)
b = tf.constant(3)
add = tf.add(a, b)

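class Hoge(tf.train.SessionRunHook):
    """Logs the value of `add` and the run duration every 10 run() calls."""

    def begin(self):
        # Called once before the session is used; start at -1 so that
        # the first run() call becomes step 0.
        self._step = -1

    def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        # Ask the session to also evaluate `add` during this run() call.
        return tf.train.SessionRunArgs(add)

    def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        # Holds the value of the fetches requested in before_run.
        result = run_values.results
        if self._step % 10 == 0:
            format_str = 'RESULT: {}, STEP:{}, {:%Y-%m-%d %H:%M:%S}, {:.2}'
            print(format_str.format(result, self._step, datetime.now(), duration))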

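with tf.train.MonitoredTrainingSession(
        hooks=[Hoge()]) as mon_sess:
    for _ in range(50):
        # Each run() call triggers the hook's before_run/after_run
        # (fetching `add` here is an assumption).
        mon_sess.run(add)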


RESULT: 5, STEP:0, 2017-01-24 17:01:32, 0.11
RESULT: 5, STEP:10, 2017-01-24 17:01:33, 0.00098
RESULT: 5, STEP:20, 2017-01-24 17:01:33, 0.0
RESULT: 5, STEP:30, 2017-01-24 17:01:33, 0.0
RESULT: 5, STEP:40, 2017-01-24 17:01:33, 0.00098
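The last column of this output looks uneven because the format string uses `{:.2}` rather than `{:.2f}`: without a type code, Python formats a float to two significant digits instead of two decimal places. The snippet below illustrates this with made-up duration values (they are not measurements from the run above).

```python
def fmt(duration):
    # Same format spec as in the hook: precision 2, no type code.
    return '{:.2}'.format(duration)


# Illustrative values only, chosen to resemble the durations printed above.
print(fmt(0.1134))                 # 0.11
print(fmt(0.000977))               # 0.00098
print(fmt(0.0))                    # 0.0
# With an explicit 'f' type, the precision means decimal places instead.
print('{:.2f}'.format(0.000977))   # 0.00
```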

As a result, log files were written to the logs directory and could be viewed in TensorBoard. Also, returning tf.train.SessionRunArgs(add) from before_run makes the result available in run_values.results inside after_run.
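The before_run/after_run hand-off can be pictured with a small plain-Python model. This is only a toy sketch of the call order, not TensorFlow's implementation: `ToyRunArgs`, `ToyRunValues`, `ToySession`, and `RecordingHook` are hypothetical names, and "tensors" are represented by zero-argument functions.

```python
class ToyRunArgs:
    """Hypothetical stand-in for tf.train.SessionRunArgs."""

    def __init__(self, fetches):
        self.fetches = fetches


class ToyRunValues:
    """Hypothetical stand-in for the run_values passed to after_run."""

    def __init__(self, results):
        self.results = results


class ToySession:
    """Calls hooks around each run(); 'tensors' are zero-argument callables."""

    def __init__(self, hooks):
        self._hooks = list(hooks)
        for hook in self._hooks:
            hook.begin()

    def run(self, fetch):
        # 1. Ask every hook what extra value it wants evaluated.
        requests = [hook.before_run(None) for hook in self._hooks]
        # 2. Evaluate the caller's fetch and each hook's requested fetch.
        result = fetch()
        extras = [req.fetches() if req is not None else None
                  for req in requests]
        # 3. Hand each hook the value it asked for.
        for hook, extra in zip(self._hooks, extras):
            hook.after_run(None, ToyRunValues(extra))
        return result


def add():
    # Plays the role of the `add` tensor (2 + 3).
    return 2 + 3


class RecordingHook:
    def begin(self):
        self.seen = []

    def before_run(self, run_context):
        return ToyRunArgs(add)

    def after_run(self, run_context, run_values):
        self.seen.append(run_values.results)


hook = RecordingHook()
session = ToySession([hook])
for _ in range(3):
    session.run(add)
print(hook.seen)  # [5, 5, 5]
```

The point of the design is that a hook never calls the session itself: it only declares what it wants in before_run and receives the value in after_run.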