Logicky BLOG

Logickyの開発ブログです

  • Javascript
  • Python
  • PHP
  • Go
  • OS・サーバ
  • 機械学習
  • つくったもの
  • 数学
  • アルゴリズム
  • Logicky

TensorFlow - Readerクラスでバッチ処理

参考: TensorFlow : How To : データを読む Inputs and Readers TensorFlowチュートリアル - 畳み込みニューラルネットワーク(翻訳)

tf.train.shuffle_batchというのを使う。シャッフルが不要な時は、tf.train.batchを使う。

tf.train.shuffle_batch

tf.train.shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, num_threads=1, seed=None, enqueue_many=False, shapes=None, allow_smaller_final_batch=False, shared_name=None, name=None)

# min_after_dequeue はバッファ、そこからランダムにサンプリングします、がどのくらい大きいかを定義します # - 大きければより良いシャッフリング、しかし開始が遅くなりメモリがより多く使用されることを意味します。 # capacity は min_after_dequeue よりも大きくなければなりません。 # そして the amount larger は事前読み込みする最大値を決めます。 # 推奨: # min_after_dequeue + (num_threads + 小さい安全マージン) * batch_size min_after_dequeue = 10000 capacity = min_after_dequeue + 3 * batch_size example_batch, label_batch = tf.train.shuffle_batch( [example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue) return example_batch, label_batch

TensorFlowのcifar10のチュートリアルのコード(cifar10_input.py)に下記のように使われている。

# Create a queue that shuffles the examples, and then
# read 'batch_size' images + labels from the example queue.
num_preprocess_threads = 16
if shuffle:
images, label_batch = tf.train.shuffle_batch(
    [image, label],
    batch_size=batch_size,
    num_threads=num_preprocess_threads,
    capacity=min_queue_examples + 3 * batch_size,
    min_after_dequeue=min_queue_examples)
else:
images, label_batch = tf.train.batch(
    [image, label],
    batch_size=batch_size,
    num_threads=num_preprocess_threads,
    capacity=min_queue_examples + 3 * batch_size)

上記コードのmin_after_dequeue(min_queue_examples)は、下記のように取得されている。

# Ensure that the random shuffling has good mixing properties.
min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                       min_fraction_of_examples_in_queue)
print ('Filling queue with %d CIFAR images before starting to train. '
     'This will take a few minutes.' % min_queue_examples)

上記のNUM_EXAMPLES_PER_EPOCH_FOR_TRAINは下記のように設定されている。

NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000

バッチもシャッフルしてくれる。シャッフルするために、事前にデータを何個か読み込んでおいて、その中からシャッフルしてくれる。事前読み込み数がmin_after_dequeueで、これが多いと沢山の中からシャッフルできるからシャッフルされ具合がいい感じになる。でも読み込み数が多いからその分遅くなる。でも途中からは裏で読み込んでおいてくれるので遅くならない。ただメモリ使用量は増える。って感じすか??上記事例だとエポック当たり5万件のデータがある場合、それに0.4をかけた数をmin_after_dequeueにしている。そして推奨値の数式に従って、capacityは、min_after_dequeue + 3 * バッチサイズにしている。

バッチ処理してみる

コードサンプル 下記のfile0.csvとfile1.csvはここで使ったものと同じ内容。

import tensorflow as tf

filenames = ["./hoge/file0.csv", "./hoge/file1.csv"]
num_threads = 2
batch_size = 3
num_per_epoch = 16

filename_queue = tf.train.string_input_producer(filenames)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)
features = tf.pack([col1, col2, col3, col4])
label = col5

min_after_dequese = int(num_per_epoch * 0.4)
features, label_batch = tf.train.shuffle_batch(
    [features, label],
    batch_size=batch_size,
    num_threads=num_threads,
    capacity=min_after_dequese + 3 * batch_size,
    min_after_dequeue=min_after_dequese)

labels = tf.reshape(label_batch, [batch_size])

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(num_per_epoch):
        x, t = sess.run([features, labels])
        print(x)
        print(t)

    coord.request_stop()
    coord.join(threads)

結果

[[600 601 602 603]
 [300 301 302 303]
 [100 101 102 103]]
[16 13 11]
[[200 201 202 203]
 [300 301 302 303]
 [400 401 402 403]]
[12  3 14]
[[800 801 802 803]
 [500 501 502 503]
 [600 601 602 603]]
[18  5  6]
[[700 701 702 703]
 [400 401 402 403]
 [200 201 202 203]]
[17  4  2]
[[100 101 102 103]
 [400 401 402 403]
 [800 801 802 803]]
[1 4 8]
[[100 101 102 103]
 [400 401 402 403]
 [700 701 702 703]]
[ 1 14  7]
[[200 201 202 203]
 [500 501 502 503]
 [600 601 602 603]]
[12  5  6]
[[400 401 402 403]
 [200 201 202 203]
 [300 301 302 303]]
[ 4  2 13]
[[300 301 302 303]
 [600 601 602 603]
 [700 701 702 703]]
[3 6 7]
[[800 801 802 803]
 [500 501 502 503]
 [100 101 102 103]]
[ 8 15 11]
[[100 101 102 103]
 [300 301 302 303]
 [600 601 602 603]]
[ 1  3 16]
[[500 501 502 503]
 [100 101 102 103]
 [700 701 702 703]]
[15 11 17]
[[200 201 202 203]
 [800 801 802 803]
 [600 601 602 603]]
[ 2 18 16]
[[700 701 702 703]
 [400 401 402 403]
 [300 301 302 303]]
[17 14 13]
[[100 101 102 103]
 [500 501 502 503]
 [100 101 102 103]]
[11  5  1]
[[200 201 202 203]
 [400 401 402 403]
 [800 801 802 803]]
[12 14  8]
  • Javascript
  • Python
  • PHP
  • Go
  • OS・サーバ
  • 機械学習
  • つくったもの
  • 数学
  • アルゴリズム
  • Logicky