TensorFlow Tensor indexing: NumPy-like indexing using tensors


Example

This example is based on this post: TensorFlow - numpy-like tensor indexing.

In NumPy you can use arrays to index into an array. For example, to select the elements at (1, 2) and (3, 2) in a 2-dimensional array, you can pass one list of row indices and one list of column indices:

import numpy as np

# data is [[ 0,  1,  2,  3,  4,  5],
#          [ 6,  7,  8,  9, 10, 11],
#          [12, 13, 14, 15, 16, 17],
#          [18, 19, 20, 21, 22, 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
a = [1, 3]  # row indices
b = [2, 2]  # column indices
selected = data[a, b]  # picks data[1, 2] and data[3, 2]
print(selected)

This will print:

[ 8 20]
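Integer-array indexing pairs the index lists element by element: entry k of the result is data[a[k], b[k]]. It does not select the full sub-block of rows a and columns b. The short sketch below checks this against an explicit loop; the names fancy and looped are only illustrative.

```python
import numpy as np

# Same 5x6 array as above: data[i, j] == 6 * i + j
data = np.reshape(np.arange(30), [5, 6])
a = [1, 3]  # row indices
b = [2, 2]  # column indices

# Integer-array indexing: element k is data[a[k], b[k]]
fancy = data[a, b]

# The same selection written as an explicit loop over the (row, col) pairs
looped = np.array([data[i, j] for i, j in zip(a, b)])

print(fancy)                          # [ 8 20]
print(np.array_equal(fancy, looped))  # True
```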

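To get the same behaviour in TensorFlow, you can use tf.gather_nd, which is an extension of tf.gather: where tf.gather picks slices along a single axis, tf.gather_nd takes index tuples that address several leading dimensions at once. The above example can be written like this: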

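import tensorflow as tf  # written for TensorFlow 1.x (graph mode)

x = tf.constant(data)
idx1 = tf.constant(a)
idx2 = tf.constant(b)
# Pair the row and column indices into a [2, 2] tensor of (row, col) coordinates
result = tf.gather_nd(x, tf.stack((idx1, idx2), -1))

with tf.Session() as sess:
    print(sess.run(result))
# TensorFlow 2.x runs eagerly and has no tf.Session; use print(result.numpy())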

This will print:

[ 8 20]
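When every innermost row of the index tensor is a full coordinate into x (as here, with two indices for a 2-D tensor), tf.gather_nd computes the same thing as NumPy integer-array indexing. The following is a minimal NumPy sketch of that case only, not TensorFlow's implementation. gather_nd_sketch is a hypothetical helper, and it deliberately ignores the slice-gathering case where index rows are shorter than the rank of the input.

```python
import numpy as np

def gather_nd_sketch(params, indices):
    """Hypothetical NumPy analogue of tf.gather_nd, restricted to the case
    indices.shape[-1] == params.ndim (each index row is a full coordinate)."""
    params = np.asarray(params)
    indices = np.asarray(indices)
    if indices.shape[-1] != params.ndim:
        raise ValueError("sketch only handles full coordinates")
    # Move the coordinate axis to the front and split it into one index
    # array per dimension, then apply NumPy integer-array indexing.
    return params[tuple(np.moveaxis(indices, -1, 0))]

data = np.reshape(np.arange(30), [5, 6])
idx = np.array([[1, 2], [3, 2]])  # same (row, col) pairs as the tf.stack result
print(gather_nd_sketch(data, idx))  # [ 8 20]
```

Because the coordinate axis is split off generically, a batched index tensor of shape [..., 2] returns a result of shape [...], which matches how tf.gather_nd treats the leading index dimensions.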

tf.stack is the equivalent of np.stack: it joins the two 1-D index vectors along a new last axis (axis -1, which for 1-D inputs is axis 1, the second dimension), so that each row holds one (row, column) pair:

[[1 2]
 [3 2]]
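The index matrix can be checked with np.stack before it is passed to tf.gather_nd. This sketch also shows that stacking along axis 0 instead would produce the transpose, which gather_nd would read as the coordinates (1, 3) and (2, 2) rather than the intended pairs.

```python
import numpy as np

a = [1, 3]  # row indices
b = [2, 2]  # column indices

# Stacking along the last axis pairs a[k] with b[k]: row k is (a[k], b[k])
idx = np.stack((a, b), -1)
print(idx)
# [[1 2]
#  [3 2]]

# Stacking along axis 0 keeps a and b as separate rows (the transpose),
# which is not the coordinate layout gather_nd expects here
wrong = np.stack((a, b), 0)
print(wrong)
# [[1 3]
#  [2 2]]
```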