TensorFlowには、畳み込み演算用の関数があります。tf.nn.conv2dです。下記のように使います。
tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
xは、元画像です。Wはフィルターのパラメタです。stridesはフィルタの適用間隔です。paddingはフィルタ適用時に画像領域が足りない時どうするかです。SAMEの場合は、足りない場合、元画像に0を足します。
- xは、(画像の枚数、縦、横、チャンネル数)という配列にします。
- Wは、(縦、横、チャンネル数、フィルター数)という配列にします。
- 各チャンネルを合計した配列を返してくれるようです。
- xとWのチャンネル数が不一致だと怒られます。
とりあえず色々やって試してみます。
import numpy as np from PIL import Image import tensorflow as tf x = np.array([[ [[1], [2], [3], [4]], [[5], [6], [7], [8]], [[9], [10], [11], [12]] ]]) print(x.shape) W = np.array([[ [[1]],[[1]] ]]) print(W.shape) x = tf.constant(x, dtype=tf.float32) W = tf.constant(W, dtype=tf.float32) conv = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') with tf.Session() as sess: result = sess.run(conv) print(result.shape) print(result)
これは、 x.shapeが、(1, 3, 4, 1) W.shapeが、(1, 2, 1, 1) tf.nn.conv2dの結果のshapeが、(1, 3, 4, 1) 結果内容は、下記です。
[[[[ 3.] [ 5.] [ 7.] [ 4.]] [[ 11.] [ 13.] [ 15.] [ 8.]] [[ 19.] [ 21.] [ 23.] [ 12.]]]]
フィルタの値は2つとも1なので、一番左上だと、1 * 1 + 2 * 1 = 3が入ってます。 一番右上は、4 * 1 + 0 * 1 = 4になります。 tf.nn.conv2dのパラメタのpaddingをVALIDにすると、フィルタ適用結果は下記のようになります。足りないところはちょん切られてしまいます。
(1, 3, 3, 1) [[[[ 3.] [ 5.] [ 7.]] [[ 11.] [ 13.] [ 15.]] [[ 19.] [ 21.] [ 23.]]]]
tf.nn.conv2dのstridesは下記のようなルールだそうです。
[1, dy, dx, 1]となっていて、縦方向にdyピクセル毎、横方向にdxピクセル毎にフィルタを適用する。
どうも、dy, dx以外の1は固定値だそうです。。本当ならなぜあるのだろうか。今度調べようかな。
strides=[1, 1, 2, 1]にしてみます。
(1, 3, 2, 1) [[[[ 3.] [ 7.]] [[ 11.] [ 15.]] [[ 19.] [ 23.]]]]
横は1個飛ばしでフィルタ適用するので、横だけ半分になりました。
ではめんどくさいですが、チャンネル数が2の場合についても確認してみようと思います。
x = np.array([[ [[1, 1], [2, 2], [3, 3], [4, 4]], [[5, 5], [6, 6], [7, 7], [8, 8]], [[9, 9], [10, 10], [11, 11], [12, 12]] ]]) print(x.shape) W = np.array([[ [[1], [1]], [[1], [1]] ]]) print(W.shape) x = tf.constant(x, dtype=tf.float32) W = tf.constant(W, dtype=tf.float32) conv = tf.nn.conv2d(x, W, strides=[1, 1, 2, 1], padding='SAME') with tf.Session() as sess: result = sess.run(conv) print(result.shape) print(result)
これは、 xのshapeが、(1, 3, 4, 2) Wのshapeが、(1, 2, 2, 1)
結果は下記です。
(1, 3, 2, 1) [[[[ 6.] [ 14.]] [[ 22.] [ 30.]] [[ 38.] [ 46.]]]]
おーやっぱり完全に足されてました。