AI

TensorFlow - 畳み込み演算の関数 tf.nn.conv2d

TensorFlowには、畳み込み演算用の関数があります。tf.nn.conv2dです。下記のように使います。

tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

xは、元画像です。Wはフィルターのパラメタです。stridesはフィルタの適用間隔です。paddingはフィルタ適用時に画像領域が足りない時どうするかです。SAMEの場合は、足りない場合、元画像に0を足します。

  • xは、(画像の枚数、縦、横、チャンネル数)という配列にします。
  • Wは、(縦、横、チャンネル数、フィルター数)という配列にします。
  • 各チャンネルを合計した配列を返してくれるようです。
  • xとWのチャンネル数が不一致だと怒られます。

とりあえず色々やって試してみます。

import numpy as np
from PIL import Image
import tensorflow as tf
x = np.array([[
[[1], [2], [3], [4]],
[[5], [6], [7], [8]],
[[9], [10], [11], [12]]
]])
print(x.shape)
W = np.array([[
[[1]],[[1]]
]])
print(W.shape)
x = tf.constant(x, dtype=tf.float32)
W = tf.constant(W, dtype=tf.float32)
conv = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
with tf.Session() as sess:
result = sess.run(conv)
print(result.shape)
print(result)

これは、 x.shapeが、(1, 3, 4, 1) W.shapeが、(1, 2, 1, 1) tf.nn.conv2dの結果のshapeが、(1, 3, 4, 1) 結果内容は、下記です。

[[[[ 3.]
[ 5.]
[ 7.]
[ 4.]]
[[ 11.]
[ 13.]
[ 15.]
[ 8.]]
[[ 19.]
[ 21.]
[ 23.]
[ 12.]]]]

フィルタの値は2つとも1なので、一番左上だと、1 * 1 + 2 * 1 = 3が入ってます。 一番右上は、4 * 1 + 0 * 1 = 4になります。 tf.nn.conv2dのパラメタのpaddingをVALIDにすると、フィルタ適用結果は下記のようになります。足りないところはちょん切られてしまいます。

(1, 3, 3, 1)
[[[[ 3.]
[ 5.]
[ 7.]]
[[ 11.]
[ 13.]
[ 15.]]
[[ 19.]
[ 21.]
[ 23.]]]]

tf.nn.conv2dのstridesは下記のようなルールだそうです。

[1, dy, dx, 1]となっていて、縦方向にdyピクセル毎、横方向にdxピクセル毎にフィルタを適用する。

どうも、dy, dx以外の1は固定値だそうです。。本当ならなぜあるのだろうか。今度調べようかな。

strides=[1, 1, 2, 1]にしてみます。

(1, 3, 2, 1)
[[[[ 3.]
[ 7.]]
[[ 11.]
[ 15.]]
[[ 19.]
[ 23.]]]]

横は1個飛ばしでフィルタ適用するので、横だけ半分になりました。

ではめんどくさいですが、チャンネル数が2の場合についても確認してみようと思います。

x = np.array([[
[[1, 1], [2, 2], [3, 3], [4, 4]],
[[5, 5], [6, 6], [7, 7], [8, 8]],
[[9, 9], [10, 10], [11, 11], [12, 12]]
]])
print(x.shape)
W = np.array([[
[[1], [1]],
[[1], [1]]
]])
print(W.shape)
x = tf.constant(x, dtype=tf.float32)
W = tf.constant(W, dtype=tf.float32)
conv = tf.nn.conv2d(x, W, strides=[1, 1, 2, 1], padding='SAME')
with tf.Session() as sess:
result = sess.run(conv)
print(result.shape)
print(result)

これは、 xのshapeが、(1, 3, 4, 2) Wのshapeが、(1, 2, 2, 1)

結果は下記です。

(1, 3, 2, 1)
[[[[ 6.]
[ 14.]]
[[ 22.]
[ 30.]]
[[ 38.]
[ 46.]]]]

おーやっぱり完全に足されてました。