`tf.gather_nd` is an extension of `tf.gather` in the sense that it allows you to access not only the 1st dimension of a tensor, but potentially all of them.

Arguments:

- `params`: a Tensor of rank `P` representing the tensor we want to index into
- `indices`: a Tensor of rank `Q` representing the indices into `params` we want to access

The output of the function depends on the shape of `indices`. If the innermost dimension of `indices` has length `P`, we are collecting single elements from `params`. If it is less than `P`, we are collecting slices, just like with `tf.gather`, but without the restriction that we can only access the 1st dimension.
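To make this rule concrete, here is a minimal NumPy sketch of the semantics just described. `gather_nd_np` and `gather_nd_shape` are hypothetical helpers written for illustration, not TensorFlow's implementation, and they assume non-negative, in-range indices. Each index tuple in the innermost dimension of `indices` addresses the leading dimensions of `params`. Whatever dimensions are left over come back as slices, so the output shape is `indices.shape[:-1] + params.shape[indices.shape[-1]:]`.

```python
import numpy as np


def gather_nd_np(params, indices):
    """Hypothetical NumPy sketch of tf.gather_nd (assumes in-range, non-negative indices)."""
    params = np.asarray(params)
    indices = np.asarray(indices)
    # Turn the innermost dimension of indices into one index array per addressed axis.
    # NumPy advanced indexing then picks single elements when the index length equals
    # the rank of params, or whole slices when it is shorter.
    per_axis = tuple(np.moveaxis(indices, -1, 0))
    return params[per_axis]


def gather_nd_shape(params_shape, indices_shape):
    """Expected output shape: outer dims of indices, then the unindexed dims of params."""
    depth = indices_shape[-1]  # how many leading dims of params each index tuple addresses
    return tuple(indices_shape[:-1]) + tuple(params_shape[depth:])


data = np.reshape(np.arange(30), [5, 6])
print(gather_nd_np(data, [1, 2]))            # 8 (a single element)
print(gather_nd_np(data, [[1], [3]]).shape)  # (2, 6) (two rows of length 6)
print(gather_nd_shape((5, 6), (2, 1)))       # (2, 6)
```

The choice of `np.moveaxis` keeps the helper rank-agnostic. The same few lines cover every case in the examples that follow.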

**Collecting elements from a tensor of rank 2**

To access the element at `(1, 2)` in a matrix, we can use:

```
import numpy as np
import tensorflow as tf

# data is [[ 0,  1,  2,  3,  4,  5],
#          [ 6,  7,  8,  9, 10, 11],
#          [12, 13, 14, 15, 16, 17],
#          [18, 19, 20, 21, 22, 23],
#          [24, 25, 26, 27, 28, 29]]
data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [1, 2])  # the element in row 1, column 2
```

where `result` will just be `8`, as expected. Note how this is different from `tf.gather`: the same indices passed to `tf.gather(x, [1, 2])` would have given us the 2nd and 3rd *row* from `data`.
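The same distinction exists in plain NumPy indexing, which can serve as a mental model. This is only an analogy, not how TensorFlow implements either op. `tf.gather_nd` treats `[1, 2]` as one coordinate, while `tf.gather` treats it as a list of positions along the first axis.

```python
import numpy as np

data = np.reshape(np.arange(30), [5, 6])

# Like tf.gather_nd(x, [1, 2]): [1, 2] is a single (row, column) coordinate.
element = data[1, 2]
print(element)  # 8

# Like tf.gather(x, [1, 2]): [1, 2] lists positions along axis 0, i.e. whole rows.
rows = data[[1, 2]]
print(rows)
```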

If you want to retrieve more than one element at the same time, just pass a list of index pairs:

```
result = tf.gather_nd(x, [[1, 2], [4, 3], [2, 5]])
```

which will return `[ 8 27 17]`
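In NumPy terms, a list of index pairs corresponds to passing one array of row indices and one array of column indices. Here is a small sketch of that correspondence, offered as an analogy and rebuilding the same `data` so it runs on its own:

```python
import numpy as np

data = np.reshape(np.arange(30), [5, 6])
pairs = np.array([[1, 2], [4, 3], [2, 5]])

# Column 0 of `pairs` holds row indices and column 1 holds column indices.
# Advanced indexing then picks one element per pair, matching the tf.gather_nd result above.
picked = data[pairs[:, 0], pairs[:, 1]]
print(picked)  # [ 8 27 17]
```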

**Collecting rows from a tensor of rank 2**

If in the above example you want to collect rows (i.e. slices) instead of elements, adjust the `indices` parameter as follows:

```
import numpy as np
import tensorflow as tf

data = np.reshape(np.arange(30), [5, 6])
x = tf.constant(data)
result = tf.gather_nd(x, [[1], [3]])  # each index has length 1, so whole rows are returned
```

This will give you the 2nd and 4th rows of `data`, i.e.

```
[[ 6  7  8  9 10 11]
 [18 19 20 21 22 23]]
```
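Each index here has length 1 while `data` has rank 2, so every index selects a whole row. The output shape follows from that: the outer dimensions of `indices` (here `2`), followed by the dimensions of `data` that were not indexed (here `6`). Here is a quick NumPy check of that arithmetic, written as an illustration rather than TensorFlow code:

```python
import numpy as np

data = np.reshape(np.arange(30), [5, 6])
indices = np.array([[1], [3]])

depth = indices.shape[-1]                            # 1, so only the first axis is indexed
out_shape = indices.shape[:-1] + data.shape[depth:]  # (2,) + (6,)
print(out_shape)  # (2, 6)

# With a single indexed axis, the gathered slices are simply rows 1 and 3.
rows = data[indices[:, 0]]
print(rows.shape == out_shape)  # True
```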

**Collecting elements from a tensor of rank 3**

The concept of how to access rank-2 tensors translates directly to higher-dimensional tensors. So, to access elements in a rank-3 tensor, the innermost dimension of `indices` must have length 3.

```
import numpy as np
import tensorflow as tf

# data is [[[ 0  1]
#           [ 2  3]
#           [ 4  5]]
#
#          [[ 6  7]
#           [ 8  9]
#           [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[0, 0, 0], [1, 2, 1]])  # two full (depth, row, column) coordinates
```

`result` will now look like this: `[ 0 11]`
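Because `data` was built with `np.arange`, each element's value equals its row-major offset. That gives an easy way to double-check results like `[ 0 11]`. For the index `(1, 2, 1)` in a tensor of shape `(2, 3, 2)`, the offset is `1*6 + 2*2 + 1 = 11`. The following NumPy check uses `np.ravel_multi_index` to do the same arithmetic:

```python
import numpy as np

shape = (2, 3, 2)
data = np.reshape(np.arange(12), shape)
indices = np.array([[0, 0, 0], [1, 2, 1]])

# Row-major offset of each index triple; with arange data, offset == value.
offsets = np.ravel_multi_index(tuple(indices.T), shape)
print(offsets)  # [ 0 11]

# Picking the same elements directly, with one index array per axis.
values = data[indices[:, 0], indices[:, 1], indices[:, 2]]
print(values)  # [ 0 11]
```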

**Collecting batched rows from a tensor of rank 3**

Let's think of a rank-3 tensor as a batch of matrices shaped `(batch_size, m, n)`. If you want to collect the first and second row for every element in the batch, you could use this:

```
import numpy as np
import tensorflow as tf

# data is [[[ 0  1]
#           [ 2  3]
#           [ 4  5]]
#
#          [[ 6  7]
#           [ 8  9]
#           [10 11]]]
data = np.reshape(np.arange(12), [2, 3, 2])
x = tf.constant(data)
result = tf.gather_nd(x, [[[0, 0], [0, 1]], [[1, 0], [1, 1]]])  # (batch, row) pairs
```

which will result in this:

```
[[[0 1]
  [2 3]]

 [[6 7]
  [8 9]]]
```
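Here `indices` has shape `(2, 2, 2)`: a grid of four `(batch, row)` pairs, each addressing the first two dimensions of `data`. The output therefore has shape `(2, 2)` followed by the remaining dimension `(2,)`, i.e. `(2, 2, 2)`. Because these pairs select rows 0 and 1 of every batch element, the result happens to match a plain slice. The NumPy sketch below confirms that equivalence for this specific example only:

```python
import numpy as np

data = np.reshape(np.arange(12), [2, 3, 2])
indices = np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]])

# indices[..., 0] holds batch positions and indices[..., 1] holds row positions;
# the last axis of data is not indexed, so it is kept as a slice.
gathered = data[indices[..., 0], indices[..., 1]]
print(gathered.shape)  # (2, 2, 2)

# For this particular choice of indices, the result is the first two rows of each matrix.
print(np.array_equal(gathered, data[:, :2, :]))  # True
```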

Note how the shape of `indices` influences the shape of the output tensor. If we had used a rank-2 tensor for the `indices` argument:

```
result = tf.gather_nd(x, [[0, 0], [0, 1], [1, 0], [1, 1]])
```

the output would have been

```
[[0 1]
 [2 3]
 [6 7]
 [8 9]]
```
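Both calls gather the same four rows; only the outer shape of `indices` differs, and the output shape mirrors it. In NumPy terms, flattening the `(2, 2, 2)` index array to `(4, 2)` flattens the output from `(2, 2, 2)` to `(4, 2)` in the same way. This is an illustration, not TensorFlow code, and `pick` is a hypothetical helper:

```python
import numpy as np

data = np.reshape(np.arange(12), [2, 3, 2])
batched = np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]])  # shape (2, 2, 2)
flat = batched.reshape(-1, 2)                             # shape (4, 2)


def pick(idx):
    # Hypothetical helper: one index array per addressed axis (batch, row);
    # the last axis of data is kept as a slice.
    return data[idx[..., 0], idx[..., 1]]


print(pick(batched).shape)  # (2, 2, 2)
print(pick(flat))
# Reshaping the batched result gives exactly the flat result.
print(np.array_equal(pick(batched).reshape(4, 2), pick(flat)))  # True
```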

This modified text is an extract of the original Stack Overflow Documentation created by following contributors and released under CC BY-SA 3.0
