TensorFlow Tensor indexing: How to use tf.gather_nd


Example

tf.gather_nd is an extension of tf.gather in the sense that it lets you index not just the 1st dimension of a tensor, but potentially all of them.

Arguments:

  • params: a Tensor of rank P representing the tensor we want to index into
  • indices: a Tensor of rank Q representing the indices into params we want to access

The output of the function depends on the shape of indices. Call the length of the innermost (last) dimension of indices K, with K <= P. If K equals P, each index vector addresses a single element of params. If K is less than P, each index vector addresses a slice of rank P - K, just like with tf.gather but without the restriction that we can only access the 1st dimension.
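Stated as a shape rule, the output has shape indices.shape[:-1] + params.shape[K:]. The sketch below illustrates that rule with NumPy advanced indexing. It is an analogy only: gather_nd_np is a hypothetical helper, not part of TensorFlow, and it assumes tf.gather_nd's default batch_dims=0.

```python
import numpy as np

def gather_nd_np(params, indices):
    """Hypothetical NumPy sketch of tf.gather_nd semantics (batch_dims=0)."""
    params = np.asarray(params)
    indices = np.asarray(indices)
    k = indices.shape[-1]  # number of leading params dimensions being indexed
    if k > params.ndim:
        raise ValueError("innermost dimension of indices must be <= rank of params")
    # Turn the last axis of indices into k separate index arrays (one per indexed
    # dimension); advanced indexing broadcasts them over indices.shape[:-1].
    return params[tuple(np.moveaxis(indices, -1, 0))]

params = np.arange(30).reshape(5, 6)               # rank P = 2
elements = gather_nd_np(params, [[1, 2], [4, 3]])  # K == P: single elements
rows = gather_nd_np(params, [[1], [3]])            # K < P: whole rows

print(elements.shape)  # (2,)   == (2,) + params.shape[2:]
print(rows.shape)      # (2, 6) == (2,) + params.shape[1:]
```

Splitting the index vectors with np.moveaxis keeps the batch shape of indices intact, which is why the same helper covers both the element and the slice case.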


Collecting elements from a tensor of rank 2

To access the element at (1, 2) in a matrix, we can use:

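import numpy as np
import tensorflow as tf

# data is [[0, 1, 2, 3, 4, 5],
#          [6, 7, 8, 9, 10, 11],
#          [12, 13, 14, 15, 16, 17],
#          [18, 19, 20, 21, 22, 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
# A single index pair (row 1, column 2) selects exactly one element.
result = tf.gather_nd(x, [1, 2])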

Here result is just 8, as expected. Note how this differs from tf.gather: the same indices passed to tf.gather(x, [1, 2]) would have given us the 2nd and 3rd rows of data.
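The difference comes down to how the list [1, 2] is interpreted. As a rough analogy (NumPy indexing, not the TensorFlow API), the two interpretations look like this:

```python
import numpy as np

data = np.reshape(np.arange(30), [5, 6])

# gather_nd-style: [1, 2] is ONE index pair (row 1, column 2) -> a single element
element = data[1, 2]

# gather-style: [1, 2] is a list of indices along the first dimension -> two rows
rows = data[[1, 2]]

print(element)  # 8
print(rows)     # [[ 6  7  8  9 10 11]
                #  [12 13 14 15 16 17]]
```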

If you want to retrieve more than one element at the same time, just pass a list of index pairs:

result = tf.gather_nd(x, [[1, 2], [4, 3], [2, 5]])

which will return [ 8 27 17]
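tf.gather_nd takes one (row, column) pair per output element, whereas NumPy advanced indexing expects one index array per dimension. The sketch below, a NumPy analogy rather than TensorFlow code, reproduces the same three elements by splitting the pairs column-wise:

```python
import numpy as np

data = np.reshape(np.arange(30), [5, 6])
pairs = np.array([[1, 2], [4, 3], [2, 5]])  # one (row, column) pair per output element

# pairs[:, 0] holds the row indices, pairs[:, 1] the column indices.
result = data[pairs[:, 0], pairs[:, 1]]

print(result)  # [ 8 27 17]
```

Each value can be checked by hand: row r, column c of data holds 6 * r + c, so (4, 3) gives 27 and (2, 5) gives 17.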


Collecting rows from a tensor of rank 2

If in the above example you want to collect rows (i.e. slices) instead of elements, adjust the indices parameter as follows:

data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [[1], [3]])

This gives you the 2nd and 4th rows of data, i.e.

[[ 6  7  8  9 10 11]
 [18 19 20 21 22 23]]
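Here the innermost dimension of indices has length 1, one less than the rank of data, so each index vector only fixes the row and the remaining column dimension is kept whole. In NumPy terms (an analogy, not TensorFlow code) the output shape is indices.shape[:-1] + data.shape[1:]:

```python
import numpy as np

data = np.reshape(np.arange(30), [5, 6])
indices = np.array([[1], [3]])  # shape (2, 1): innermost length 1 < rank 2

# Only the first dimension is indexed, so each index selects a whole row:
# output shape == indices.shape[:-1] + data.shape[1:] == (2,) + (6,)
result = data[indices[:, 0]]

print(result.shape)  # (2, 6)
```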

Collecting elements from a tensor of rank 3

The same approach carries over directly to higher-dimensional tensors. To access single elements in a rank-3 tensor, the innermost dimension of indices must have length 3.

# data is [[[ 0  1]
#          [ 2  3]
#          [ 4  5]]
#
#         [[ 6  7]
#          [ 8  9]
#          [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[0, 0, 0], [1, 2, 1]])

result will now look like this: [ 0 11]
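To see why (1, 2, 1) yields 11: in row-major order an index (i, j, k) of a [2, 3, 2] array maps to the flat offset i*6 + j*2 + k, and because data is arange(12), offset and value coincide. The NumPy sketch below (an analogy, not TensorFlow code) gathers the same two elements and cross-checks them with np.ravel_multi_index:

```python
import numpy as np

data = np.reshape(np.arange(12), [2, 3, 2])
triples = np.array([[0, 0, 0], [1, 2, 1]])  # innermost length 3 == rank 3

# One index array per dimension; each triple addresses a single element.
result = data[triples[:, 0], triples[:, 1], triples[:, 2]]
print(result)  # [ 0 11]

# Flat row-major offsets of the same triples: (1, 2, 1) -> 6 + 4 + 1 = 11.
flat = np.ravel_multi_index(triples.T, data.shape)
print(flat)  # [ 0 11]
```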


Collecting batched rows from a tensor of rank 3

Let's think of a rank-3 tensor as a batch of matrices shaped (batch_size, m, n). If you want to collect the first and second rows of every matrix in the batch, you could use this:

# data is [[[ 0  1]
#          [ 2  3]
#          [ 4  5]]
#
#         [[ 6  7]
#          [ 8  9]
#          [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[[0, 0], [0, 1]], [[1, 0], [1, 1]]])

which will result in this:

[[[0 1]
  [2 3]]

 [[6 7]
  [8 9]]]
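Writing the index tensor by hand gets tedious for larger batches. One possible way to build it, sketched here with NumPy under the assumption that the same rows are wanted from every matrix (batch_size and row_ids are illustrative names), is to pair every batch index with every row index and stack the pairs along a new last axis:

```python
import numpy as np

data = np.reshape(np.arange(12), [2, 3, 2])
batch_size = data.shape[0]
row_ids = np.array([0, 1])  # rows to collect from every matrix in the batch

# b[i, j] = i (batch index) and r[i, j] = row_ids[j] (row index)
b, r = np.meshgrid(np.arange(batch_size), row_ids, indexing="ij")
indices = np.stack([b, r], axis=-1)  # shape (batch_size, len(row_ids), 2)
print(indices.tolist())  # [[[0, 0], [0, 1]], [[1, 0], [1, 1]]]

# NumPy analogue of tf.gather_nd(x, indices): shape (2, 2) + data.shape[2:]
result = data[indices[..., 0], indices[..., 1]]
print(result.shape)  # (2, 2, 2)
```

The resulting indices array matches the hand-written list used in the tf.gather_nd call, so it could be passed in its place.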

Note how the shape of indices determines the shape of the output tensor. Had we used a rank-2 tensor for the indices argument instead:

result = tf.gather_nd(x, [[0, 0], [0, 1], [1, 0], [1, 1]])

the output would have been

[[0 1]
 [2 3]
 [6 7]
 [8 9]]
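Both calls use the same four index pairs; only the leading (batch) shape of indices differs, (2, 2) versus (4,), and that leading shape is carried straight into the output. The NumPy sketch below (an analogy, not TensorFlow code) shows that the rank-2 result is simply the rank-3 result reshaped:

```python
import numpy as np

data = np.reshape(np.arange(12), [2, 3, 2])
nested = np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]])  # shape (2, 2, 2)
flat = nested.reshape(-1, 2)                              # shape (4, 2)

out_nested = data[nested[..., 0], nested[..., 1]]  # shape (2, 2) + (2,)
out_flat = data[flat[:, 0], flat[:, 1]]            # shape (4,) + (2,)

print(out_nested.shape, out_flat.shape)  # (2, 2, 2) (4, 2)
print(np.array_equal(out_flat, out_nested.reshape(4, 2)))  # True
```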