diff --git a/libavcodec/riscv/ac3dsp_init.c b/libavcodec/riscv/ac3dsp_init.c
index be5e153fac..e120aa2dce 100644
--- a/libavcodec/riscv/ac3dsp_init.c
+++ b/libavcodec/riscv/ac3dsp_init.c
@@ -30,6 +30,8 @@ void ff_extract_exponents_rvb(uint8_t *exp, int32_t *coef, int nb_coefs);
 void ff_float_to_fixed24_rvv(int32_t *dst, const float *src, size_t len);
 void ff_sum_square_butterfly_int32_rvv(int64_t *, const int32_t *,
                                        const int32_t *, int);
+void ff_sum_square_butterfly_float_rvv(float *, const float *,
+                                       const float *, int);
 
 av_cold void ff_ac3dsp_init_riscv(AC3DSPContext *c)
 {
@@ -39,8 +41,10 @@ av_cold void ff_ac3dsp_init_riscv(AC3DSPContext *c)
     if (flags & AV_CPU_FLAG_RVB_ADDR) {
         if (flags & AV_CPU_FLAG_RVB_BASIC)
             c->extract_exponents = ff_extract_exponents_rvb;
-        if (flags & AV_CPU_FLAG_RVV_F32)
+        if (flags & AV_CPU_FLAG_RVV_F32) {
             c->float_to_fixed24 = ff_float_to_fixed24_rvv;
+            c->sum_square_butterfly_float = ff_sum_square_butterfly_float_rvv;
+        }
 # if __riscv_xlen >= 64
         if (flags & AV_CPU_FLAG_RVV_I64)
             c->sum_square_butterfly_int32 = ff_sum_square_butterfly_int32_rvv;
diff --git a/libavcodec/riscv/ac3dsp_rvv.S b/libavcodec/riscv/ac3dsp_rvv.S
index dd0b4cd797..397e000ab0 100644
--- a/libavcodec/riscv/ac3dsp_rvv.S
+++ b/libavcodec/riscv/ac3dsp_rvv.S
@@ -78,3 +78,42 @@ func ff_sum_square_butterfly_int32_rvv, zve64x
         ret
 endfunc
 #endif
+
+func ff_sum_square_butterfly_float_rvv, zve32f
+        vsetvli      t0, zero, e32, m8, ta, ma
+        vmv.v.x      v0, zero
+        vmv.v.x      v8, zero
+1:
+        vsetvli      t0, a3, e32, m4, tu, ma
+        vle32.v      v16, (a1)
+        sub          a3, a3, t0
+        vle32.v      v20, (a2)
+        sh2add       a1, t0, a1
+        vfadd.vv     v24, v16, v20
+        sh2add       a2, t0, a2
+        vfsub.vv     v28, v16, v20
+        vfmacc.vv    v0, v16, v16
+        vfmacc.vv    v4, v20, v20
+        vfmacc.vv    v8, v24, v24
+        vfmacc.vv    v12, v28, v28
+        bnez         a3, 1b
+
+        vsetvli      t0, zero, e32, m4, ta, ma
+        vmv.s.x      v16, zero
+        vmv.s.x      v17, zero
+        vfredusum.vs v16, v0, v16
+        vmv.s.x      v18, zero
+        vfredusum.vs v17, v4, v17
+        vmv.s.x      v19, zero
+        vfredusum.vs v18, v8, v18
+        vfmv.f.s     ft0, v16
+        vfredusum.vs v19, v12, v19
+        vfmv.f.s     ft1, v17
+        fsw          ft0, (a0)
+        vfmv.f.s     ft2, v18
+        fsw          ft1, 4(a0)
+        vfmv.f.s     ft3, v19
+        fsw          ft2, 8(a0)
+        fsw          ft3, 12(a0)
+        ret
+endfunc
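
For context on the DSP hook being added: the routine accumulates, over len coefficient pairs, the squared left, right, mid (l+r) and side (l-r) energies used by the AC-3 encoder's stereo rematrixing decision, and writes the four totals to sum[0..3]. In the RVV function above, a0 holds the output array, a1/a2 the two coefficient buffers and a3 the length; v0/v4/v8/v12 are the four vector accumulators, reduced to scalars at the end. A minimal scalar sketch of the same computation follows (the function name is illustrative, not the C reference in libavcodec/ac3dsp.c):

static void sum_square_butterfly_float_ref(float sum[4], const float *coef0,
                                           const float *coef1, int len)
{
    float ll = 0.f, rr = 0.f, mm = 0.f, ss = 0.f;

    for (int i = 0; i < len; i++) {
        float l = coef0[i];          /* loaded into v16 above */
        float r = coef1[i];          /* loaded into v20 above */
        float m = l + r;             /* vfadd.vv  v24 */
        float s = l - r;             /* vfsub.vv  v28 */

        ll += l * l;                 /* vfmacc.vv v0  */
        rr += r * r;                 /* vfmacc.vv v4  */
        mm += m * m;                 /* vfmacc.vv v8  */
        ss += s * s;                 /* vfmacc.vv v12 */
    }
    sum[0] = ll;
    sum[1] = rr;
    sum[2] = mm;
    sum[3] = ss;
}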