TensorFlow - tf.train.MonitoredTrainingSession
tf.train.MonitoredTrainingSessionを確認します。 訓練をモニターするのに特化したセッションという感じでしょうか?チュートリアルのコードでは下記のような使われ方をしています。
with tf.train.MonitoredTrainingSession( checkpoint_dir=FLAGS.train_dir, hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), tf.train.NanTensorHook(loss), _LoggerHook()], config=tf.ConfigProto( log_device_placement=FLAGS.log_device_placement)) as mon_sess:while not mon_sess.should_stop(): mon_sess.run(train_op)普通のセッションを使う代わりに使っています。
https://www.tensorflow.org/api_docs/python/train/distributed_execution#MonitoredTrainingSession
tf.train.MonitoredTrainingSession(master=”, is_chief=True, checkpoint_dir=None, scaffold=None, hooks=None, chief_only_hooks=None, save_checkpoint_secs=600, save_summaries_steps=100, config=None)
引数的に、勝手に変数の保存をしてくれたりするようです。
For a chief, this utility sets proper session initializer/restorer. It also creates hooks related to checkpoint and summary saving. For workers, this utility sets proper session creator which waits for the chief to inialize/restore.
引数は下記です。
master:Stringthe TensorFlow master to use.is_chief: IfTrue, it will take care of initialization and recovery the underlying TensorFlow session. IfFalse, it will wait on a chief to initialize or recover the TensorFlow session.checkpoint_dir: A string. Optional path to a directory where to restore variables.scaffold: AScaffoldused for gathering or building supportive ops. If not specified, a default one is created. It’s used to finalize the graph.hooks: Optional list ofSessionRunHookobjects.chief_only_hooks: list ofSessionRunHookobjects. Activate these hooks ifis_chief==True, ignore otherwise.save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved using a default checkpoint saver. Ifsave_checkpoint_secsis set toNone, then the default checkpoint saver isn’t used.save_summaries_steps: The frequency, in number of global steps, that the summaries are written to disk using a default summary saver. Ifsave_summaries_stepsis set toNone, then the default summary saver isn’t used.config: an instance oftf.ConfigProtoproto used to configure the session. It’s theconfigargument of constructor oftf.Session.
hookというのはコールバック的な感じで、session.runの前後に実行できるクラスらしい。これを紐づけることができて、リストで複数のhookを登録することができる。カスタマイズできるのが、SessionRunHookで、それ以外に用途が決まっているhookが複数事前に提供されているといったような感じのイメージを持った気がする。
class tf.train.SessionRunHook class tf.train.StopAtStepHook class tf.train.NanTensorHook
configは設定で、log_device_placementは、手動でデバイス指定してる場合にTrueにすると、手動設定したデバイスを選んでくれる的なやつっぽい。自分の非力な1体のPC環境ではあまり関係がないと思う。非力な1体のPC環境の場合、hookは自作する_LoggerHook()だけで、configは未設定でもよさそう。
引数でcheckpoint_dirを聞いてるくらいだから、tf.summary.FileWriter(’.\logs’, sess.graph)みたいのを書かなくてもログが保存されるようになったりしてるのか試してみようと思います。
import tensorflow as tffrom datetime import datetimeimport time
logdir = './logs'
a = tf.constant(2)b = tf.constant(3)add = tf.add(a, b)
class Hoge(tf.train.SessionRunHook): def begin(self): self._step = -1
def before_run(self, run_context): self._step += 1 self._start_time = time.time() return tf.train.SessionRunArgs(add)
def after_run(self, run_context, run_values): duration = time.time() - self._start_time result = run_values.results if self._step % 10 == 0: format_str = 'RESULT: {}, STEP:{}, {:%Y-%m-%d %H:%M:%S}, {:.2}' print(format_str.format(result, self._step, datetime.now(), duration))
tf.contrib.framework.get_or_create_global_step()with tf.train.MonitoredTrainingSession( checkpoint_dir=logdir, hooks=[Hoge()]) as mon_sess: for _ in range(50): mon_sess.run(add)結果
RESULT: 5, STEP:0, 2017-01-24 17:01:32, 0.11RESULT: 5, STEP:10, 2017-01-24 17:01:33, 0.00098RESULT: 5, STEP:20, 2017-01-24 17:01:33, 0.0RESULT: 5, STEP:30, 2017-01-24 17:01:33, 0.0RESULT: 5, STEP:40, 2017-01-24 17:01:33, 0.00098そして、logsディレクトリにログファイルが登録されていて、TensorBoardでも見られた。 before_runのreturnで、tf.train.SessionRunArgs(add)とやると、after_runのrun_values.resultsに結果が入ってくる。