# TensorFlow Tensor indexing: How to use `tf.gather_nd`

## Example

`tf.gather_nd` is an extension of `tf.gather` in the sense that it allows you to not only access the 1st dimension of a tensor, but potentially all of them.

Arguments:

• `params`: a Tensor of rank `P` representing the tensor we want to index into
• `indices`: a Tensor of rank `Q` representing the indices into `params` we want to access

The output of the function depends on the shape of `indices`. If the innermost dimension of `indices` has length `P`, we are collecting single elements from `params`. If it is less than `P`, we are collecting slices, just like with `tf.gather` but without the restriction that we can only access the 1st dimension.
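The shape rule can be written down compactly: the output has shape `indices.shape[:-1] + params.shape[k:]`, where `k` is the length of the innermost dimension of `indices`. The following is a minimal NumPy sketch of that rule. `gather_nd_reference` is a hypothetical helper written for illustration only, not part of TensorFlow, and it assumes all indices are in range.

```python
import numpy as np


def gather_nd_reference(params, indices):
    """Hypothetical NumPy mirror of tf.gather_nd's basic behaviour (illustration only)."""
    params = np.asarray(params)
    indices = np.asarray(indices)
    k = indices.shape[-1]  # number of leading params dimensions each index addresses
    if k > params.ndim:
        raise ValueError("innermost dimension of indices must not exceed rank of params")
    # Output shape rule: indices.shape[:-1] + params.shape[k:]
    out_shape = indices.shape[:-1] + params.shape[k:]
    # Treat every innermost vector of indices as one coordinate tuple.
    flat = indices.reshape(-1, k)
    # k == rank picks single elements; k < rank picks whole slices.
    picked = [params[tuple(idx)] for idx in flat]
    return np.array(picked).reshape(out_shape)


data = np.arange(30).reshape(5, 6)
print(gather_nd_reference(data, [1, 2]))                  # 8
print(gather_nd_reference(data, [[1, 2], [4, 3]]).shape)  # (2,)
print(gather_nd_reference(data, [[1], [3]]).shape)        # (2, 6)
```

With an innermost length of 2 on a rank-2 tensor, each index tuple is a full coordinate and yields a scalar; with length 1, each tuple fixes only the row and the whole row comes back.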

### Collecting elements from a tensor of rank 2

To access the element at `(1, 2)` in a matrix, we can use:

```python
import numpy as np
import tensorflow as tf

# data is [[ 0,  1,  2,  3,  4,  5],
#          [ 6,  7,  8,  9, 10, 11],
#          [12, 13, 14, 15, 16, 17],
#          [18, 19, 20, 21, 22, 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [1, 2])
```

where `result` will be `8`, as expected. Note how this differs from `tf.gather`: passing the same indices to `tf.gather(x, [1, 2])` would have given us the 2nd and 3rd rows of `data`.
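The same contrast can be seen with plain NumPy indexing. This sketch assumes NumPy's list indexing along axis 0 corresponds to `tf.gather` and a coordinate tuple corresponds to `tf.gather_nd`; it does not call TensorFlow.

```python
import numpy as np

data = np.arange(30).reshape(5, 6)

# Like tf.gather(x, [1, 2]): the list selects entries along axis 0 (whole rows).
rows = data[[1, 2]]
print(rows.shape)  # (2, 6)

# Like tf.gather_nd(x, [1, 2]): the list is one coordinate tuple (row 1, column 2).
element = data[1, 2]
print(element)     # 8
```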

If you want to retrieve more than one element at the same time, just pass a list of index pairs:

```python
result = tf.gather_nd(x, [[1, 2], [4, 3], [2, 5]])
```

which will return `[ 8 27 17]`
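A list of index pairs is organised the opposite way from NumPy's advanced indexing, which wants one array per axis. As a sketch (NumPy only, assuming in-range indices), transposing the pairs converts between the two layouts and reproduces the result:

```python
import numpy as np

data = np.arange(30).reshape(5, 6)
index_pairs = np.array([[1, 2], [4, 3], [2, 5]])  # same pairs as the tf.gather_nd call

# Transposing turns the list of (row, col) pairs into one array of rows and
# one array of columns, which is the form NumPy advanced indexing expects.
rows, cols = index_pairs.T
picked = data[rows, cols]
print(picked)  # [ 8 27 17]
```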

### Collecting rows from a tensor of rank 2

If in the above example you want to collect rows (i.e. slices) instead of elements, adjust the `indices` parameter as follows:

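```python
import numpy as np
import tensorflow as tf

data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
# Each index has length 1 (< rank 2), so it selects a whole row.
result = tf.gather_nd(x, [[1], [3]])
```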

This will give you the 2nd and 4th row of `data`, i.e.

```
[[ 6  7  8  9 10 11]
 [18 19 20 21 22 23]]
```
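When the innermost dimension of `indices` is 1, `tf.gather_nd` fixes only axis 0, so it should select the same rows as `tf.gather(x, [1, 3])`. The NumPy sketch below checks that equivalence under the assumption that NumPy indexing mirrors both calls for in-range indices.

```python
import numpy as np

data = np.arange(30).reshape(5, 6)
slice_indices = np.array([[1], [3]])  # innermost length 1 < rank 2 -> slices

# Each length-1 tuple fixes only axis 0, so every remaining axis comes back whole.
picked_rows = np.stack([data[tuple(idx)] for idx in slice_indices])
print(picked_rows)
# [[ 6  7  8  9 10 11]
#  [18 19 20 21 22 23]]

# For the single-axis case this matches plain row indexing, i.e. what tf.gather does.
assert (picked_rows == data[[1, 3]]).all()
```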

### Collecting elements from a tensor of rank 3

The approach for rank-2 tensors translates directly to higher-dimensional tensors. To access single elements in a rank-3 tensor, the innermost dimension of `indices` must have length 3.

```python
import numpy as np
import tensorflow as tf

# data is [[[ 0  1]
#           [ 2  3]
#           [ 4  5]]
#
#          [[ 6  7]
#           [ 8  9]
#           [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[0, 0, 0], [1, 2, 1]])
```

`result` will now look like this: `[ 0 11]`
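If the coordinates you need are held as separate per-axis lists, stacking them along the last axis produces the `indices` layout `tf.gather_nd` expects (in TensorFlow, `tf.stack(..., axis=-1)` plays the same role). The sketch below uses NumPy; the lists `batch`, `row` and `col` are hypothetical names chosen for this example.

```python
import numpy as np

data = np.arange(12).reshape(2, 3, 2)

# Hypothetical per-axis coordinates for the two elements we want.
batch = [0, 1]
row = [0, 2]
col = [0, 1]

# Stacking on the last axis yields [[0, 0, 0], [1, 2, 1]]: one (batch, row, col)
# triple per element, i.e. the indices argument used above.
indices = np.stack([batch, row, col], axis=-1)
print(indices.tolist())  # [[0, 0, 0], [1, 2, 1]]

# NumPy's equivalent lookup takes the per-axis lists directly.
print(data[batch, row, col])  # [ 0 11]
```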

### Collecting batched rows from a tensor of rank 3

Let's think of a rank-3 tensor as a batch of matrices shaped `(batch_size, m, n)`. If you want to collect the first and second row for every element in the batch, you could use this:

```python
import numpy as np
import tensorflow as tf

# data is [[[ 0  1]
#           [ 2  3]
#           [ 4  5]]
#
#          [[ 6  7]
#           [ 8  9]
#           [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[[0, 0], [0, 1]], [[1, 0], [1, 1]]])
```

which will result in this:

```
[[[0 1]
  [2 3]]

 [[6 7]
  [8 9]]]
```
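For larger batches, writing the nested `indices` by hand quickly becomes impractical. One way to build it, sketched here in NumPy under the assumption that the result should match slicing the same rows out of every matrix, is to broadcast batch ids against row ids and stack them. `rows_wanted`, `batch_ids` and `row_ids` are hypothetical names for this illustration.

```python
import numpy as np

data = np.arange(12).reshape(2, 3, 2)
batch_size = data.shape[0]
rows_wanted = [0, 1]  # first and second row of every matrix

# Broadcast batch ids (shape [batch, 1]) against row ids (shape [1, rows]) to
# get every (batch, row) combination, then stack them on a new last axis.
batch_ids = np.arange(batch_size)[:, None]
row_ids = np.array(rows_wanted)[None, :]
b, r = np.broadcast_arrays(batch_ids, row_ids)
indices = np.stack([b, r], axis=-1)
print(indices.tolist())  # [[[0, 0], [0, 1]], [[1, 0], [1, 1]]]

# Selecting with those pairs equals slicing the rows out of each matrix.
gathered = data[indices[..., 0], indices[..., 1]]
assert (gathered == data[:, rows_wanted]).all()
print(gathered.shape)  # (2, 2, 2)
```

The resulting `indices` is exactly the nested list passed to `tf.gather_nd` above, so the same construction can feed that call.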

Note how the shape of `indices` determines the shape of the output tensor. Had we used a rank-2 tensor for the `indices` argument:

```python
result = tf.gather_nd(x, [[0, 0], [0, 1], [1, 0], [1, 1]])
```

the output would have been

```
[[0 1]
 [2 3]
 [6 7]
 [8 9]]
```
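The two outputs contain the same rows in the same order; only the leading dimensions differ, because the output shape is `indices.shape[:-1]` followed by the remaining dimensions of `params`. A NumPy sketch of that relationship, assuming NumPy advanced indexing behaves like `tf.gather_nd` for these in-range indices:

```python
import numpy as np

data = np.arange(12).reshape(2, 3, 2)

nested = np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]])  # shape (2, 2, 2)
flat = nested.reshape(-1, 2)                            # shape (4, 2)

# Each (batch, row) pair selects one row of length 2, so the output shape is
# indices.shape[:-1] + (2,).
out_nested = data[nested[..., 0], nested[..., 1]]
out_flat = data[flat[:, 0], flat[:, 1]]
print(out_nested.shape, out_flat.shape)  # (2, 2, 2) (4, 2)

# Same rows, same order; only the leading (batch) structure differs.
assert (out_flat == out_nested.reshape(-1, 2)).all()
```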