CUDA: Single-block parallel reduction for a non-commutative operator


Example

Parallel reduction with a non-commutative operator is a bit more involved than the commutative version, because partial results must always be combined in their original left-to-right order. For simplicity, the example still uses addition over integers. It could be replaced, for example, by matrix multiplication, which really is non-commutative. The operator must still be associative. When it is replaced, the 0 used for padding should also be replaced by the neutral element of the new operator, i.e. an identity matrix for matrix multiplication.
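To make that substitution concrete, here is a minimal sketch of such an operator. The 2x2 integer matrix type `Mat2` and the helpers `identity`, `combine` and `sameMat` are hypothetical names chosen for illustration; they are not part of CUDA. The helpers are marked `__host__ __device__` so they could be called from the kernel and also checked on the host.

```cuda
// Hypothetical 2x2 integer matrix, row-major: [[a, b], [c, d]].
struct Mat2 { int a, b, c, d; };

// Neutral element of matrix multiplication; takes the place of the padding 0.
__host__ __device__ inline Mat2 identity() { return Mat2{1, 0, 0, 1}; }

// Takes the place of '+'. x must be the operand that comes first in the input.
__host__ __device__ inline Mat2 combine(const Mat2 &x, const Mat2 &y) {
    return Mat2{x.a * y.a + x.b * y.c, x.a * y.b + x.b * y.d,
                x.c * y.a + x.d * y.c, x.c * y.b + x.d * y.d};
}

// Element-wise equality, handy for host-side checks.
__host__ __device__ inline bool sameMat(const Mat2 &x, const Mat2 &y) {
    return x.a == y.a && x.b == y.b && x.c == y.c && x.d == y.d;
}

// In the kernel, shArr would become __shared__ Mat2 shArr[blockSize*2], and e.g.
//   shArr[thIdx] = thIdx < arraySize ? gArr[thIdx] : identity();
//   Mat2 sum = combine(shArr[2*thIdx], shArr[2*thIdx+1]);
//   shArr[arrIdx] = combine(shArr[arrIdx], shArr[arrIdx+stride]);
```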

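static const int arraySize = 1000000;
static const int blockSize = 1024;

__global__ void sumNoncommSingleBlock(const int *gArr, int *out) {
    int thIdx = threadIdx.x;
    __shared__ int shArr[blockSize*2]; // first half: partial results, second half: newly loaded data
    __shared__ int offset;             // index in gArr of the next chunk to load
    shArr[thIdx] = thIdx<arraySize ? gArr[thIdx] : 0; // 0 is the neutral element of +
    if (thIdx == 0)
        offset = blockSize;
    __syncthreads();
    while (offset < arraySize) { // uniform: every thread sees the same offset
        shArr[thIdx + blockSize] = thIdx+offset<arraySize ? gArr[thIdx+offset] : 0;
        __syncthreads();
        if (thIdx == 0)
            offset += blockSize; // safe: offset is not read again before the next __syncthreads()
        int sum = shArr[2*thIdx] + shArr[2*thIdx+1]; // combine neighbours, left operand first
        __syncthreads();   // all reads must finish before the first half is overwritten
        shArr[thIdx] = sum; // compress 2*blockSize values into the first half
    }
    __syncthreads();
    for (int stride = 1; stride<blockSize; stride*=2) { // uniform
        int arrIdx = thIdx*stride*2;
        if (arrIdx+stride<blockSize)
            shArr[arrIdx] += shArr[arrIdx+stride]; // left operand stays on the left
        __syncthreads();
    }
    if (thIdx == 0)
        *out = shArr[0];
}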

...

sumNoncommSingleBlock<<<1, blockSize>>>(dev_a, dev_out);

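The first loop (the while loop) executes as long as some input elements have not been loaded yet, so it is entered only when there are more input elements than threads. In each iteration, the second half of shArr is filled with the next blockSize elements of gArr, and then thread n combines elements 2n and 2n+1, compressing the result into the first half of shArr. Because the left operand always precedes the right one in the input, the original order of the operands is preserved.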
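The ordering argument can be checked without a GPU. The following is a sequential host-side emulation of the first loop, a sketch rather than the kernel itself: each inner `for` loop over `t` stands for all threads between two `__syncthreads()` calls, and `offset` is a local variable instead of the shared one. Input element `i` is modelled with a hypothetical order-checking operator in which each value is the index range `Range{lo, hi}` it covers; `combine` only accepts adjacent ranges in left-to-right order and otherwise yields an error value, so it is associative but not commutative. `Range`, `emptyRange`, `errRange`, `combine` and `emulateLoadPhase` are illustrative names only.

```cuda
#include <vector>

// Hypothetical order-checking operator: a value is the input range [lo, hi] it covers.
struct Range { int lo, hi; };
__host__ __device__ inline Range emptyRange() { return Range{-1, -1}; } // neutral element
__host__ __device__ inline Range errRange()   { return Range{-2, -2}; } // absorbing error
__host__ __device__ inline Range combine(Range a, Range b) {
    if (a.lo == -2 || b.lo == -2) return errRange();
    if (a.lo == -1) return b;
    if (b.lo == -1) return a;
    if (a.hi + 1 == b.lo) return Range{a.lo, b.hi}; // adjacent and in order
    return errRange();                               // operands out of order
}

// Sequential emulation of the while loop; returns the first half of shArr.
std::vector<Range> emulateLoadPhase(int arraySize, int blockSize) {
    std::vector<Range> shArr(2 * blockSize, emptyRange());
    for (int t = 0; t < blockSize; ++t)  // initial load; gArr[i] is modelled as Range{i, i}
        shArr[t] = t < arraySize ? Range{t, t} : emptyRange();
    for (int offset = blockSize; offset < arraySize; offset += blockSize) {
        for (int t = 0; t < blockSize; ++t)  // fill the second half with new data
            shArr[t + blockSize] = t + offset < arraySize ? Range{t + offset, t + offset}
                                                          : emptyRange();
        std::vector<Range> sum(blockSize);
        for (int t = 0; t < blockSize; ++t)  // every "thread" reads before anyone writes
            sum[t] = combine(shArr[2 * t], shArr[2 * t + 1]);
        for (int t = 0; t < blockSize; ++t)  // compress into the first half
            shArr[t] = sum[t];
    }
    shArr.resize(blockSize);
    return shArr;
}
```

With `arraySize = 10` and `blockSize = 4`, the first half ends up as `[0,3] [4,7] [8,9]` followed by the empty range: every slot covers a contiguous run of the input, and the slots appear in input order.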

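Once all data has been loaded from gArr, the second loop executes. Now we no longer compress the result (compressing would cost an extra __syncthreads() per step, as in the first loop). Instead, in each step thread n accesses the 2n-th active element and combines it with the (2n+1)-th active element. With stride s, the active elements sit at multiples of s, so thread n updates shArr[2*n*s] with shArr[2*n*s+s], and after the last step shArr[0] holds the full result.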
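The second loop can be emulated on the host in the same way. The sketch below reuses the hypothetical order-checking `Range` operator (redefined so the block stands alone; all names are illustrative). Within one step the "threads" write only even multiples of `stride` and read only odd multiples, so running them one after another gives the same result as running them in parallel.

```cuda
#include <vector>

// Hypothetical order-checking operator: a value is the input range [lo, hi] it covers.
struct Range { int lo, hi; };
__host__ __device__ inline Range emptyRange() { return Range{-1, -1}; } // neutral element
__host__ __device__ inline Range errRange()   { return Range{-2, -2}; } // absorbing error
__host__ __device__ inline Range combine(Range a, Range b) {
    if (a.lo == -2 || b.lo == -2) return errRange();
    if (a.lo == -1) return b;
    if (b.lo == -1) return a;
    if (a.hi + 1 == b.lo) return Range{a.lo, b.hi}; // adjacent and in order
    return errRange();                               // operands out of order
}

// Sequential emulation of the tree loop over the blockSize partial results.
Range emulateTreePhase(std::vector<Range> shArr) {
    const int blockSize = static_cast<int>(shArr.size());
    for (int stride = 1; stride < blockSize; stride *= 2) {
        for (int t = 0; t < blockSize; ++t) {  // one pass = one step of all threads
            int arrIdx = t * stride * 2;
            if (arrIdx + stride < blockSize)   // left operand stays on the left
                shArr[arrIdx] = combine(shArr[arrIdx], shArr[arrIdx + stride]);
        }
    }
    return shArr[0];
}
```

Feeding it the output of the first phase from the previous example, `[0,3] [4,7] [8,9]` plus the empty range, yields `[0,9]`. Swapping two partial results yields the error value instead, which is exactly what would silently go wrong with a genuinely non-commutative operator.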

There are many ways to further optimize this simple example, e.g. through warp-level reduction and by removing shared memory bank conflicts.
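Warp-level reduction needs some care here. The shuffle pattern often shown for sums uses decreasing offsets (16, 8, 4, 2, 1), which pairs lane i with lane i+16 first and therefore interleaves the operands; that is only valid for commutative operators. A sketch of an order-preserving variant, assuming a full warp of 32 active threads, uses increasing offsets instead, so that lane i always combines its own partial result (left) with the one held by lane i+offset (right). `__shfl_down_sync` is the CUDA intrinsic; the `Range` operator, `warpReduceOrdered`, `warpDemoKernel` and `runWarpDemo` are hypothetical names for illustration. Combining the per-warp results across the block (in warp order) is not shown.

```cuda
#include <cuda_runtime.h>

// Hypothetical order-checking operator: a value is the input range [lo, hi] it covers.
struct Range { int lo, hi; };
__host__ __device__ inline Range errRange() { return Range{-2, -2}; } // absorbing error
__host__ __device__ inline Range combine(Range a, Range b) {
    if (a.lo == -2 || b.lo == -2) return errRange();
    if (a.lo == -1) return b;                        // -1 marks the neutral element
    if (b.lo == -1) return a;
    if (a.hi + 1 == b.lo) return Range{a.lo, b.hi};
    return errRange();
}

// Offsets 1, 2, 4, 8, 16: after the step with offset k, lane i holds lanes
// i..i+2k-1 combined in order (as long as they exist). Only lane 0's final
// value is meaningful; lanes near the top may end up holding the error value.
__device__ Range warpReduceOrdered(Range v) {
    for (int offset = 1; offset < 32; offset *= 2) {
        Range right;  // shuffle each field separately
        right.lo = __shfl_down_sync(0xffffffffu, v.lo, offset);
        right.hi = __shfl_down_sync(0xffffffffu, v.hi, offset);
        v = combine(v, right);
    }
    return v;
}

// Demo: lane i contributes Range{i, i}; lane 0 writes the reduced value.
__global__ void warpDemoKernel(Range *out) {
    int lane = static_cast<int>(threadIdx.x);
    Range r = warpReduceOrdered(Range{lane, lane});
    if (lane == 0) *out = r;
}

// Host helper: runs one warp and returns lane 0's result (error value on failure).
Range runWarpDemo() {
    Range *dOut = nullptr;
    Range hOut = errRange();
    if (cudaMalloc(&dOut, sizeof(Range)) != cudaSuccess) return hOut;
    warpDemoKernel<<<1, 32>>>(dOut);
    cudaMemcpy(&hOut, dOut, sizeof(Range), cudaMemcpyDeviceToHost);
    cudaFree(dOut);
    return hOut;
}
```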