Logicky Blog

Logickyの開発ブログです

TensorFlow - 畳み込み演算の関数 tf.nn.conv2d

TensorFlowには、畳み込み演算用の関数があります。tf.nn.conv2dです。下記のように使います。

tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

xは、元画像です。Wはフィルターのパラメタです。stridesはフィルタの適用間隔です。paddingはフィルタ適用時に画像領域が足りない時どうするかです。SAMEの場合は、足りない場合、元画像に0を足します。

  • xは、(画像の枚数、縦、横、チャンネル数)という配列にします。
  • Wは、(縦、横、チャンネル数、フィルター数)という配列にします。
  • 各チャンネルを合計した配列を返してくれるようです。
  • xとWのチャンネル数が不一致だと怒られます。

とりあえず色々やって試してみます。

import numpy as np
from PIL import Image
import tensorflow as tf

x = np.array([[
    [[1], [2], [3], [4]],
    [[5], [6], [7], [8]],
    [[9], [10], [11], [12]]
]])
print(x.shape)
W = np.array([[
    [[1]],[[1]]
]])
print(W.shape)
x = tf.constant(x, dtype=tf.float32)
W = tf.constant(W, dtype=tf.float32)
conv = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

with tf.Session() as sess:
    result = sess.run(conv)
    print(result.shape)
    print(result)

これは、 x.shapeが、(1, 3, 4, 1) W.shapeが、(1, 2, 1, 1) tf.nn.conv2dの結果のshapeが、(1, 3, 4, 1) 結果内容は、下記です。

[[[[  3.]
   [  5.]
   [  7.]
   [  4.]]

  [[ 11.]
   [ 13.]
   [ 15.]
   [  8.]]

  [[ 19.]
   [ 21.]
   [ 23.]
   [ 12.]]]]

フィルタの値は2つとも1なので、一番左上だと、1 * 1 + 2 * 1 = 3が入ってます。 一番右上は、4 * 1 + 0 * 1 = 4になります。 tf.nn.conv2dのパラメタのpaddingをVALIDにすると、フィルタ適用結果は下記のようになります。足りないところはちょん切られてしまいます。

(1, 3, 3, 1)
[[[[  3.]
   [  5.]
   [  7.]]

  [[ 11.]
   [ 13.]
   [ 15.]]

  [[ 19.]
   [ 21.]
   [ 23.]]]]

tf.nn.conv2dのstridesは下記のようなルールだそうです。

[1, dy, dx, 1]となっていて、縦方向にdyピクセル毎、横方向にdxピクセル毎にフィルタを適用する。

どうも、dy, dx以外の1は固定値だそうです。。本当ならなぜあるのだろうか。今度調べようかな。

strides=[1, 1, 2, 1]にしてみます。

(1, 3, 2, 1)
[[[[  3.]
   [  7.]]

  [[ 11.]
   [ 15.]]

  [[ 19.]
   [ 23.]]]]

横は1個飛ばしでフィルタ適用するので、横だけ半分になりました。

ではめんどくさいですが、チャンネル数が2の場合についても確認してみようと思います。

x = np.array([[
    [[1, 1], [2, 2], [3, 3], [4, 4]],
    [[5, 5], [6, 6], [7, 7], [8, 8]],
    [[9, 9], [10, 10], [11, 11], [12, 12]]
]])
print(x.shape)
W = np.array([[
    [[1], [1]],
    [[1], [1]]
]])
print(W.shape)
x = tf.constant(x, dtype=tf.float32)
W = tf.constant(W, dtype=tf.float32)
conv = tf.nn.conv2d(x, W, strides=[1, 1, 2, 1], padding='SAME')

with tf.Session() as sess:
    result = sess.run(conv)
    print(result.shape)
    print(result)

これは、 xのshapeが、(1, 3, 4, 2) Wのshapeが、(1, 2, 2, 1)

結果は下記です。

(1, 3, 2, 1)
[[[[  6.]
   [ 14.]]

  [[ 22.]
   [ 30.]]

  [[ 38.]
   [ 46.]]]]

おーやっぱり完全に足されてました。