tf.gather_nd is an extension of tf.gather in the sense that it allows you to not only access the 1st dimension of a tensor, but potentially all of them.
Arguments:
params: a Tensor of rank P representing the tensor we want to index into
indices: a Tensor of rank Q representing the indices into params we want to access

The output of the function depends on the shape of indices. If the innermost dimension of indices has length P, we are collecting single elements from params. If it is less than P, we are collecting slices, just like with tf.gather but without the restriction that we can only access the 1st dimension.
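The rule above can be modelled in a few lines of plain Python on nested lists. This is only a sketch of the lookup semantics, not how TensorFlow implements it; the helper name gather_nd_reference and the small 3x4 params matrix are made up for illustration.

```python
def gather_nd_reference(params, indices):
    """Plain-Python model of gather_nd on nested lists (illustration only)."""
    # A flat list of integers is one index tuple: follow it into params.
    if all(isinstance(i, int) for i in indices):
        result = params
        for i in indices:
            result = result[i]
        return result
    # Otherwise recurse, so the outer structure of indices is preserved.
    return [gather_nd_reference(params, idx) for idx in indices]


# Hypothetical rank-2 input: 3 rows, 4 columns.
params = [[0, 1, 2, 3],
          [4, 5, 6, 7],
          [8, 9, 10, 11]]

# Innermost length 2 equals the rank of params: single elements.
elements = gather_nd_reference(params, [[0, 1], [2, 3]])

# Innermost length 1 is less than the rank of params: whole rows (slices).
rows = gather_nd_reference(params, [[0], [2]])

print(elements)
print(rows)
```

Each innermost list of indices is consumed as one coordinate; whatever dimensions of params it does not cover are returned whole, which is why shorter index tuples yield slices.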
Collecting elements from a tensor of rank 2
To access the element at (1, 2) in a matrix, we can use:
import numpy as np
import tensorflow as tf

# data is [[ 0,  1,  2,  3,  4,  5],
#          [ 6,  7,  8,  9, 10, 11],
#          [12, 13, 14, 15, 16, 17],
#          [18, 19, 20, 21, 22, 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [1, 2])
where result will just be 8 as expected. Note how this is different from tf.gather: the same indices passed to tf.gather(x, [1, 2]) would have given us the 2nd and 3rd rows of data.
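To make the contrast concrete, here is a plain-Python model of the two lookups on the same 5x6 matrix. It deliberately uses nested lists rather than TensorFlow calls, so it only illustrates how the index list [1, 2] is interpreted in each case; the variable names are illustrative.

```python
# Same 5x6 matrix as above, built as nested lists.
data = [list(range(row * 6, row * 6 + 6)) for row in range(5)]
indices = [1, 2]

# gather-style lookup: each index selects a whole row along the first dimension.
gathered_rows = [data[i] for i in indices]

# gather_nd-style lookup: the indices together form one coordinate (row 1, column 2).
row, col = indices
gathered_element = data[row][col]

print(gathered_rows)
print(gathered_element)
```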
If you want to retrieve more than one element at the same time, just pass a list of index pairs:
result = tf.gather_nd(x, [[1, 2], [4, 3], [2, 5]])
which will return [ 8 27 17]
Collecting rows from a tensor of rank 2
If in the above example you want to collect rows (i.e. slices) instead of elements, adjust the indices parameter as follows:
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [[1], [3]])
This will give you the 2nd and 4th rows of data, i.e.
[[ 6  7  8  9 10 11]
 [18 19 20 21 22 23]]
Collecting elements from a tensor of rank 3
The way rank-2 tensors are indexed carries over directly to higher-dimensional tensors. So, to access single elements in a rank-3 tensor, the innermost dimension of indices must have length 3.
# data is [[[ 0  1]
#           [ 2  3]
#           [ 4  5]]
#
#          [[ 6  7]
#           [ 8  9]
#           [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[0, 0, 0], [1, 2, 1]])
result will now look like this: [ 0 11]
Collecting batched rows from a tensor of rank 3
Let's think of a rank-3 tensor as a batch of matrices shaped (batch_size, m, n). If you want to collect the first and second row for every element in the batch, you could use this:
# data is [[[ 0  1]
#           [ 2  3]
#           [ 4  5]]
#
#          [[ 6  7]
#           [ 8  9]
#           [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[[0, 0], [0, 1]], [[1, 0], [1, 1]]])
which will result in this:
[[[0 1]
  [2 3]]

 [[6 7]
  [8 9]]]
Note how the shape of indices influences the shape of the output tensor. If we had used a rank-2 tensor for the indices argument:
result = tf.gather_nd(x, [[0, 0], [0, 1], [1, 0], [1, 1]])
the output would have been
[[0 1]
 [2 3]
 [6 7]
 [8 9]]
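The relationship between the shapes can be written down as a small rule: the output shape is the shape of indices without its innermost dimension, followed by the dimensions of params that the index tuples do not cover. The helper below is a hypothetical sketch of that rule (assuming the default behaviour without batch dimensions), not a TensorFlow function.

```python
def gather_nd_output_shape(params_shape, indices_shape):
    """Shape of the gather_nd result for the given shapes (illustration only)."""
    index_depth = indices_shape[-1]
    if index_depth > len(params_shape):
        raise ValueError("innermost dimension of indices exceeds rank of params")
    # Outer dimensions of indices, then the uncovered dimensions of params.
    return tuple(indices_shape[:-1]) + tuple(params_shape[index_depth:])


params_shape = (2, 3, 2)  # the rank-3 tensor used in the examples above

# Rank-3 indices of shape (2, 2, 2): batched rows.
batched_rows_shape = gather_nd_output_shape(params_shape, (2, 2, 2))

# Rank-2 indices of shape (4, 2): a flat list of rows.
flat_rows_shape = gather_nd_output_shape(params_shape, (4, 2))

# Indices of shape (2, 3): full coordinates, so single elements.
elements_shape = gather_nd_output_shape(params_shape, (2, 3))

print(batched_rows_shape, flat_rows_shape, elements_shape)
```

With the shapes from the examples, the three calls correspond to the 2x2x2 batched result, the 4x2 flat result, and the two-element vector from the rank-3 element lookup.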