@@ -29,7 +29,7 @@ __global__
2929void count_kernel (T* arr, int size, int *count, CountIfOp count_if_op) {
3030 int tid = threadIdx .x + blockIdx .x * blockDim .x ;
3131
32- __shared__ unsigned int local_count_array[TPB];
32+ __shared__ int local_count_array[TPB];
3333
3434 if (tid < size) {
3535 if (count_if_op (arr[tid])) {
@@ -44,11 +44,11 @@ void count_kernel(T* arr, int size, int *count, CountIfOp count_if_op) {
4444 for (int offset = blockDim .x / 2 ; offset > 0 ; offset >>=1 ) {
4545 if (threadIdx .x < offset && tid + offset < size) {
4646 local_count_array[threadIdx .x ] += local_count_array[threadIdx .x + offset];
47- __syncthreads ();
4847 }
48+ __syncthreads ();
4949 }
5050
51- if (threadIdx .x == 0 ) {
51+ if (threadIdx .x == 0 ) {
5252 atomicAdd (count, local_count_array[threadIdx .x ]);
5353 }
5454 }
@@ -69,7 +69,7 @@ void count_kernel(T* arr, int size, int *count, CountIfOp count_if_op) {
6969
7070 int block_count = __syncthreads_count (predicate);
7171
72- if (threadIdx .x == 0 ) {
72+ if (threadIdx .x == 0 ) {
7373 atomicAdd (count, block_count);
7474 }
7575
@@ -93,8 +93,34 @@ void count_kernel(T* arr, int size, int *count, CountIfOp count_if_op) {
9393 unsigned ballot_mask = __ballot_sync (FULL_MASK, predicate);
9494 int warp_count = __popc (ballot_mask);
9595
96- if (threadIdx .x % 32 == 0 ) {
97- atomicAdd (count, warp_count);
96+ // global atomics
97+ // if(threadIdx.x == 0) {
98+ // atomicAdd(count, warp_count);
99+ // }
100+
101+
102+ // optimization for block reduction
103+ __shared__ int block_counts[TPB / 32 ];
104+
105+ if (tid < size) {
106+ int warp_id = threadIdx .x / 32 ;
107+ int lane_id = threadIdx .x % 32 ;
108+ if (lane_id == 0 ) {
109+ block_counts[warp_id] = warp_count;
110+ }
111+
112+ __syncthreads ();
113+
114+ for (int offset = (TPB / 32 ) / 2 ; offset > 0 ; offset >>= 1 ) {
115+ if (lane_id == 0 && warp_id < offset && tid + offset < size) {
116+ block_counts[warp_id] += block_counts[warp_id + offset];
117+ }
118+ __syncthreads ();
119+ }
120+
121+ if (threadIdx .x == 0 ) {
122+ atomicAdd (count, block_counts[threadIdx .x ]);
123+ }
98124 }
99125
100126}
0 commit comments