tf.strided_sliceを調べます。TensorFlowのGithubにのってる説明ページはこれです。
tf.strided_slice(input_, begin, end, strides=None, begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0, var=None, name=None) {#strided_slice}
実験してみる
コード
import tensorflow as tf tensor = tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) result = tf.strided_slice(tensor, [0], [9], [2]) with tf.Session() as sess: result = sess.run(result) print(result)
結果
[1 3 5 7 9]
とりあえず、inputは元データ、beginは開始位置、endは終了位置、stridesは間隔のようです。以前はstridesを指定しない場合、デフォルトでstridesを1とみなしていたようですが、最近のtensorflowのアップデートで、stridesも明示しないとエラーになるようになったようです。ちなみに、終了位置は普通の配列のスライスと同じで、指定したインデックスのひとつ前までになります。
これだけなら簡単なのですが、他にも色々引数があるし、inputもbeginも多次元配列に対応しているようです。
コード
import tensorflow as tf tensor = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]]) result = tf.strided_slice(tensor, [0, 5], [2, 8], [1, 1]) with tf.Session() as sess: result = sess.run(result) print(result)
結果
[[ 6 7 8] [16 17 18]]
begin、end、stridesはnumpyのshapeのような感じで入れていくようです。コードだとbeginは[0, 5]ですが、これは1次元目は0から始まり、2次元目は5から始まるということになるようです。