#extension GL_EXT_buffer_reference : require
#extension GL_EXT_buffer_reference2 : require
// Needed for atomicLoad/atomicStore with explicit scope and semantics.
#extension GL_KHR_memory_scope_semantics : require

#define ACQUIRE gl_StorageSemanticsBuffer, gl_SemanticsAcquire
#define RELEASE gl_StorageSemanticsBuffer, gl_SemanticsRelease

// These correspond to X, A, P respectively in the prefix sum paper.
#define FLAG_NOT_READY 0u
#define FLAG_AGGREGATE_READY 1u
#define FLAG_PREFIX_READY 2u

layout(buffer_reference, buffer_reference_align = T_ALIGN) nonprivate buffer StateData {
    DTYPE aggregate;
    DTYPE prefix;
    uint flag;
};

shared DTYPE sh_scratch[WG_SIZE];
shared DTYPE sh_prefix;
shared uint sh_part_ix;
shared uint sh_flag;

void prefix_sum(DataBuffer dst, uint dst_stride, DataBuffer src, uint src_stride) {
    DTYPE local[N_ROWS];

    // Determine partition to process by atomic counter (described in Section 4.4 of prefix sum paper).
    if (gl_LocalInvocationID.x == 0)
        sh_part_ix = gl_WorkGroupID.x; // sh_part_ix = atomicAdd(part_counter, 1);
    barrier();
    uint part_ix = sh_part_ix;

    uint ix = part_ix * PARTITION_SIZE + gl_LocalInvocationID.x * N_ROWS;

    // TODO: gate buffer read? (evaluate whether shader check or CPU-side padding is better)
    local[0] = src.v[ix * src_stride];
    for (uint i = 1; i < N_ROWS; i++)
        local[i] = local[i - 1] + src.v[(ix + i) * src_stride];
    DTYPE agg = local[N_ROWS - 1];
    sh_scratch[gl_LocalInvocationID.x] = agg;

    // Scan of the per-thread aggregates in shared memory.
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        barrier();
        if (gl_LocalInvocationID.x >= (1u << i))
            agg += sh_scratch[gl_LocalInvocationID.x - (1u << i)];
        barrier();
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }

    // Publish aggregate for this partition.
    if (gl_LocalInvocationID.x == WG_SIZE - 1) {
        state[part_ix].aggregate = agg;
        if (part_ix == 0)
            state[0].prefix = agg;
    }
    // Write flag with release semantics.
    if (gl_LocalInvocationID.x == WG_SIZE - 1) {
        uint flag = part_ix == 0 ? FLAG_PREFIX_READY : FLAG_AGGREGATE_READY;
        atomicStore(state[part_ix].flag, flag, gl_ScopeDevice, RELEASE);
    }

    DTYPE exclusive = DTYPE(0);
    if (part_ix != 0) {
        // Step 4 of paper: decoupled look-back.
        uint look_back_ix = part_ix - 1;

        DTYPE their_agg;
        uint their_ix = 0;
        while (true) {
            // Read flag with acquire semantics.
            if (gl_LocalInvocationID.x == WG_SIZE - 1)
                sh_flag = atomicLoad(state[look_back_ix].flag, gl_ScopeDevice, ACQUIRE);
            // The flag load is done only in the last thread. However, because the
            // translation of memoryBarrierBuffer to Metal requires uniform control
            // flow, we broadcast it to all threads.
            barrier();
            uint flag = sh_flag;
            barrier();

            if (flag == FLAG_PREFIX_READY) {
                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                    DTYPE their_prefix = state[look_back_ix].prefix;
                    exclusive = their_prefix + exclusive;
                }
                break;
            } else if (flag == FLAG_AGGREGATE_READY) {
                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                    their_agg = state[look_back_ix].aggregate;
                    exclusive = their_agg + exclusive;
                }
                look_back_ix--;
                their_ix = 0;
                continue;
            }
            // else spins

            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                // Unfortunately there's no guarantee of forward progress of other
                // workgroups, so compute a bit of the aggregate before trying again.
                // In the worst case, spinning stops when the aggregate is complete.
                DTYPE m = src.v[(look_back_ix * PARTITION_SIZE + their_ix) * src_stride];
                if (their_ix == 0)
                    their_agg = m;
                else
                    their_agg += m;
                their_ix++;
                if (their_ix == PARTITION_SIZE) {
                    exclusive = their_agg + exclusive;
                    if (look_back_ix == 0) {
                        sh_flag = FLAG_PREFIX_READY;
                    } else {
                        look_back_ix--;
                        their_ix = 0;
                    }
                }
            }
            barrier();
            flag = sh_flag;
            barrier();
            if (flag == FLAG_PREFIX_READY)
                break;
        }

        // Step 5 of paper: compute inclusive prefix.
        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
            DTYPE inclusive_prefix = exclusive + agg;
            sh_prefix = exclusive;
            state[part_ix].prefix = inclusive_prefix;
        }
        if (gl_LocalInvocationID.x == WG_SIZE - 1)
            atomicStore(state[part_ix].flag, FLAG_PREFIX_READY, gl_ScopeDevice, RELEASE);
    }
    barrier();
    if (part_ix != 0)
        exclusive = sh_prefix;

    DTYPE row = exclusive;
    if (gl_LocalInvocationID.x > 0)
        row += sh_scratch[gl_LocalInvocationID.x - 1];
    // Note: when dst aliases src, this write may overwrite input; the values
    // needed were already read into local[] above.
    for (uint i = 0; i < N_ROWS; i++)
        dst.v[(ix + i) * dst_stride] = row + local[i];
}
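// ---------------------------------------------------------------------------
// Not part of the listing above: a minimal sketch of the specialization
// macros and bindings the shader assumes, inferred from how the identifiers
// are used. Names and values here (DTYPE, WG_SIZE, N_ROWS, the push-constant
// block, main) are illustrative assumptions, not the actual host setup; these
// declarations would have to precede the code above.
// ---------------------------------------------------------------------------
// #version 450
// #pragma use_vulkan_memory_model       // required for the nonprivate qualifier
//
// #define DTYPE uint                    // scan element type (any monoid works)
// #define T_ALIGN 4                     // alignment of DTYPE, in bytes
// #define LG_WG_SIZE 9
// #define WG_SIZE 512                   // threads per workgroup (1 << LG_WG_SIZE)
// #define N_ROWS 8                      // elements per thread
// #define PARTITION_SIZE (WG_SIZE * N_ROWS)  // elements per partition
//
// layout(local_size_x = WG_SIZE) in;
//
// // Input/output data, addressed through buffer references.
// layout(buffer_reference, buffer_reference_align = T_ALIGN) nonprivate buffer DataBuffer {
//     DTYPE v[];
// };
//
// // One StateData record per partition; GL_EXT_buffer_reference2 allows
// // indexing a reference like an array, so state[part_ix] is well formed.
// // The state buffer must be zero-initialized before dispatch so every flag
// // starts as FLAG_NOT_READY. (The commented-out atomicAdd variant would
// // additionally need a part_counter binding.)
// layout(push_constant) uniform Params {
//     DataBuffer dst;
//     DataBuffer src;
//     StateData state;
// };
//
// void main() {
//     prefix_sum(dst, 1u, src, 1u);
// }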