ffmpeg/libavfilter/vulkan/prefix_sum.comp

152 lines
5.1 KiB
Plaintext

// Buffer references are needed for the StateData block declared below.
#extension GL_EXT_buffer_reference : require
#extension GL_EXT_buffer_reference2 : require
// Trailing (storage semantics, memory semantics) argument pairs passed to the
// atomicLoad()/atomicStore() calls on StateData.flag in prefix_sum().
// NOTE(review): these gl_* constants come from GL_KHR_memory_scope_semantics,
// which is presumably enabled outside this file - confirm.
#define ACQUIRE gl_StorageSemanticsBuffer, gl_SemanticsAcquire
#define RELEASE gl_StorageSemanticsBuffer, gl_SemanticsRelease
// Per-partition status values stored in StateData.flag.
// These correspond to X, A, P respectively in the prefix sum paper.
//   FLAG_NOT_READY:       neither aggregate nor prefix has been published
//                         (assumes the state buffer starts out zeroed - TODO confirm).
//   FLAG_AGGREGATE_READY: StateData.aggregate has been written.
//   FLAG_PREFIX_READY:    StateData.prefix has been written.
#define FLAG_NOT_READY 0u
#define FLAG_AGGREGATE_READY 1u
#define FLAG_PREFIX_READY 2u
// Look-back state of one partition, accessed through a buffer reference.
// prefix_sum() reads and writes these as state[<partition index>]; the
// `state` reference itself is declared outside this file.
layout(buffer_reference, buffer_reference_align = T_ALIGN) nonprivate buffer StateData {
// Sum of the partition's own elements only.
DTYPE aggregate;
// Inclusive prefix: sum of all elements up to and including this partition.
DTYPE prefix;
// One of the FLAG_* values; tells other workgroups which of the two fields
// above has been published. Accessed with atomicLoad()/atomicStore().
uint flag;
};
// Workgroup-shared storage used by prefix_sum().
// Running sum held by each invocation (indexed by gl_LocalInvocationID.x)
// during the workgroup-level scan.
shared DTYPE sh_scratch[WG_SIZE];
// Exclusive prefix of this workgroup's partition, written by the last
// invocation and read back by all invocations.
shared DTYPE sh_prefix;
// Index of the partition this workgroup processes.
shared uint sh_part_ix;
// Flag most recently observed during look-back, broadcast from the last
// invocation to the whole workgroup.
shared uint sh_flag;
/**
 * Single-pass inclusive prefix sum with decoupled look-back.
 *
 * Every workgroup scans one partition of PARTITION_SIZE elements, with each
 * invocation handling N_ROWS consecutive elements. The partition's aggregate
 * is published through state[], and the sum of all preceding partitions is
 * obtained by walking backwards over the state[] entries of earlier
 * partitions (or by summing their input directly while they are not ready).
 *
 * @param dst         destination; element k is written to dst.v[k*dst_stride]
 * @param dst_stride  distance between consecutive destination elements
 * @param src         source; element k is read from src.v[k*src_stride]
 * @param src_stride  distance between consecutive source elements
 *
 * Neither reads nor writes are bounds-checked here, so both buffers must
 * cover whole partitions.
 * NOTE(review): the indexing presumes PARTITION_SIZE == WG_SIZE * N_ROWS and
 * WG_SIZE == 1 << LG_WG_SIZE - confirm against the definitions.
 */
void prefix_sum(DataBuffer dst, uint dst_stride, DataBuffer src, uint src_stride)
{
    DTYPE local[N_ROWS];

    // Determine the partition to process. The paper hands out partitions with
    // an atomic counter (Section 4.4); here the workgroup index is used.
    // The guard has to test the *local* ID: gl_GlobalInvocationID.x is zero
    // only in the first workgroup, which would leave sh_part_ix unwritten
    // (and thus undefined) in every other workgroup.
    if (gl_LocalInvocationID.x == 0)
        sh_part_ix = gl_WorkGroupID.x;

    barrier();
    uint part_ix = sh_part_ix;

    // First element handled by this invocation.
    uint ix = part_ix * PARTITION_SIZE + gl_LocalInvocationID.x * N_ROWS;

    // Sequential inclusive scan over this invocation's N_ROWS elements.
    // TODO: gate buffer read? (evaluate whether shader check or CPU-side padding is better)
    local[0] = src.v[ix*src_stride];
    for (uint i = 1; i < N_ROWS; i++)
        local[i] = local[i - 1] + src.v[(ix + i)*src_stride];

    // Inclusive scan of the per-invocation totals across the workgroup.
    // After the loop, sh_scratch[k] holds the sum of invocations 0..k.
    DTYPE agg = local[N_ROWS - 1];
    sh_scratch[gl_LocalInvocationID.x] = agg;
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        barrier();
        if (gl_LocalInvocationID.x >= (1u << i))
            agg += sh_scratch[gl_LocalInvocationID.x - (1u << i)];
        // All reads of this round must finish before anyone overwrites.
        barrier();
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }

    // Publish aggregate for this partition. The last invocation holds the
    // total of the whole partition. Partition 0 has no predecessors, so its
    // aggregate is also its inclusive prefix.
    if (gl_LocalInvocationID.x == WG_SIZE - 1) {
        state[part_ix].aggregate = agg;
        if (part_ix == 0)
            state[0].prefix = agg;
    }

    // Write flag with release semantics
    if (gl_LocalInvocationID.x == WG_SIZE - 1) {
        uint flag = part_ix == 0 ? FLAG_PREFIX_READY : FLAG_AGGREGATE_READY;
        atomicStore(state[part_ix].flag, flag, gl_ScopeDevice, RELEASE);
    }

    // Sum of all partitions before this one. Only the copy held by the last
    // invocation is accumulated; it is broadcast through sh_prefix afterwards.
    DTYPE exclusive = DTYPE(0);
    if (part_ix != 0) {
        // step 4 of paper: decoupled lookback
        uint look_back_ix = part_ix - 1;

        // Partial sum of partition look_back_ix computed by the fallback
        // path, and the number of its elements folded in so far.
        DTYPE their_agg;
        uint their_ix = 0;

        while (true) {
            // Read flag with acquire semantics.
            if (gl_LocalInvocationID.x == WG_SIZE - 1)
                sh_flag = atomicLoad(state[look_back_ix].flag, gl_ScopeDevice, ACQUIRE);

            // The flag load is done only in the last thread. However, because the
            // translation of memoryBarrierBuffer to Metal requires uniform control
            // flow, we broadcast it to all threads.
            barrier();
            uint flag = sh_flag;
            barrier();

            if (flag == FLAG_PREFIX_READY) {
                // The predecessor's prefix covers everything before it too,
                // so the look-back is complete.
                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                    DTYPE their_prefix = state[look_back_ix].prefix;
                    exclusive = their_prefix + exclusive;
                }
                break;
            } else if (flag == FLAG_AGGREGATE_READY) {
                // Only that partition's own sum is known: add it and move one
                // partition further back. Partition 0 always publishes
                // FLAG_PREFIX_READY, so look_back_ix is non-zero here.
                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                    their_agg = state[look_back_ix].aggregate;
                    exclusive = their_agg + exclusive;
                }
                look_back_ix--;
                their_ix = 0;
                continue;
            } // else spins

            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                // Unfortunately there's no guarantee of forward progress of other
                // workgroups, so compute a bit of the aggregate before trying again.
                // In the worst case, spinning stops when the aggregate is complete.
                DTYPE m = src.v[(look_back_ix * PARTITION_SIZE + their_ix)*src_stride];
                if (their_ix == 0)
                    their_agg = m;
                else
                    their_agg += m;
                their_ix++;
                if (their_ix == PARTITION_SIZE) {
                    // The whole partition was summed locally.
                    exclusive = their_agg + exclusive;
                    if (look_back_ix == 0) {
                        // Nothing left before partition 0: signal completion
                        // to the rest of the workgroup via the shared flag.
                        sh_flag = FLAG_PREFIX_READY;
                    } else {
                        look_back_ix--;
                        their_ix = 0;
                    }
                }
            }

            // Let every invocation see whether the fallback path finished.
            barrier();
            flag = sh_flag;
            barrier();
            if (flag == FLAG_PREFIX_READY)
                break;
        }

        // step 5 of paper: compute inclusive prefix
        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
            DTYPE inclusive_prefix = exclusive + agg;
            sh_prefix = exclusive;
            state[part_ix].prefix = inclusive_prefix;
        }

        // Publish the prefix; the release store orders it after the write above.
        if (gl_LocalInvocationID.x == WG_SIZE - 1)
            atomicStore(state[part_ix].flag, FLAG_PREFIX_READY, gl_ScopeDevice, RELEASE);
    }

    // Broadcast the exclusive prefix computed by the last invocation.
    barrier();
    if (part_ix != 0)
        exclusive = sh_prefix;

    // Offset for this invocation: everything before the partition plus the
    // totals of all lower-numbered invocations within it.
    DTYPE row = exclusive;
    if (gl_LocalInvocationID.x > 0)
        row += sh_scratch[gl_LocalInvocationID.x - 1];

    // note - may overwrite
    for (uint i = 0; i < N_ROWS; i++)
        dst.v[(ix + i)*dst_stride] = row + local[i];
}