From 6c814093d8a0351d0a5f5264deba2285a436e88a Mon Sep 17 00:00:00 2001 From: Paul B Mahol Date: Sun, 4 Dec 2022 17:32:04 +0100 Subject: [PATCH] avfilter/vf_bm3d: switch to TX from lavu --- configure | 3 - libavfilter/vf_bm3d.c | 332 +++++++++++++++++++++--------------------- 2 files changed, 166 insertions(+), 169 deletions(-) diff --git a/configure b/configure index 0d754e7ae9..f4eedfc207 100755 --- a/configure +++ b/configure @@ -3629,8 +3629,6 @@ avgblur_vulkan_filter_deps="vulkan spirv_compiler" azmq_filter_deps="libzmq" blackframe_filter_deps="gpl" blend_vulkan_filter_deps="vulkan spirv_compiler" -bm3d_filter_deps="avcodec" -bm3d_filter_select="dct" boxblur_filter_deps="gpl" boxblur_opencl_filter_deps="opencl gpl" bs2b_filter_deps="libbs2b" @@ -7444,7 +7442,6 @@ enabled zlib && add_cppflags -DZLIB_CONST # conditional library dependencies, in any order enabled amovie_filter && prepend avfilter_deps "avformat avcodec" enabled aresample_filter && prepend avfilter_deps "swresample" -enabled bm3d_filter && prepend avfilter_deps "avcodec" enabled cover_rect_filter && prepend avfilter_deps "avformat avcodec" enabled ebur128_filter && enabled swresample && prepend avfilter_deps "swresample" enabled elbg_filter && prepend avfilter_deps "avcodec" diff --git a/libavfilter/vf_bm3d.c b/libavfilter/vf_bm3d.c index 1167027535..14f94cf535 100644 --- a/libavfilter/vf_bm3d.c +++ b/libavfilter/vf_bm3d.c @@ -25,17 +25,17 @@ /** * @todo - * - non-power of 2 DCT * - opponent color space * - temporal support */ #include +#include "libavutil/cpu.h" #include "libavutil/imgutils.h" #include "libavutil/opt.h" #include "libavutil/pixdesc.h" -#include "libavcodec/avfft.h" +#include "libavutil/tx.h" #include "avfilter.h" #include "filters.h" #include "formats.h" @@ -69,16 +69,19 @@ typedef struct PosPairCode { } PosPairCode; typedef struct SliceContext { - DCTContext *gdctf, *gdcti; - DCTContext *dctf, *dcti; - FFTSample *bufferh; - FFTSample *bufferv; - FFTSample *bufferz; - FFTSample *buffer; - FFTSample *rbufferh; - FFTSample *rbufferv; - FFTSample *rbufferz; - FFTSample *rbuffer; + AVTXContext *gdctf, *gdcti; + av_tx_fn tx_fn_g, itx_fn_g; + AVTXContext *dctf, *dcti; + av_tx_fn tx_fn, itx_fn; + float *bufferh; + float *buffert; + float *bufferv; + float *bufferz; + float *buffer; + float *rbufferh; + float *rbufferv; + float *rbufferz; + float *rbuffer; float *num, *den; PosPairCode match_blocks[256]; int nb_match_blocks; @@ -105,7 +108,7 @@ typedef struct BM3DContext { int nb_planes; int planewidth[4]; int planeheight[4]; - int group_bits; + int pblock_size; int pgroup_size; SliceContext slices[MAX_NB_THREADS]; @@ -128,11 +131,12 @@ typedef struct BM3DContext { #define OFFSET(x) offsetof(BM3DContext, x) #define FLAGS AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_VIDEO_PARAM + static const AVOption bm3d_options[] = { { "sigma", "set denoising strength", OFFSET(sigma), AV_OPT_TYPE_FLOAT, {.dbl=1}, 0, 99999.9, FLAGS }, - { "block", "set log2(size) of local patch", - OFFSET(block_size), AV_OPT_TYPE_INT, {.i64=4}, 4, 6, FLAGS }, + { "block", "set size of local patch", + OFFSET(block_size), AV_OPT_TYPE_INT, {.i64=16}, 8, 64, FLAGS }, { "bstep", "set sliding step for processing blocks", OFFSET(block_step), AV_OPT_TYPE_INT, {.i64=4}, 1, 64, FLAGS }, { "group", "set maximal number of similar blocks", @@ -273,9 +277,9 @@ static void do_block_matching_multi(BM3DContext *s, const uint8_t *src, int src_ double MSE2SSE = s->group_size * s->block_size * s->block_size * src_range * src_range / (s->max * s->max); double distMul = 1. / MSE2SSE; double th_sse = th_mse * MSE2SSE; - int i, index = sc->nb_match_blocks; + int index = sc->nb_match_blocks; - for (i = 0; i < search_size; i++) { + for (int i = 0; i < search_size; i++) { PosCode pos = search_pos[i]; double dist; @@ -316,10 +320,10 @@ static void block_matching_multi(BM3DContext *s, const uint8_t *ref, int ref_lin int r = search_boundary(width - block_size, range, step, 0, y, x); int t = search_boundary(0, range, step, 1, y, x); int b = search_boundary(height - block_size, range, step, 1, y, x); - int j, i, index = 0; + int index = 0; - for (j = t; j <= b; j += step) { - for (i = l; i <= r; i += step) { + for (int j = t; j <= b; j += step) { + for (int i = l; i <= r; i += step) { PosCode pos; if (exclude_cur_pos > 0 && j == y && i == x) { @@ -364,22 +368,18 @@ static void get_block_row(const uint8_t *srcp, int src_linesize, int y, int x, int block_size, float *dst) { const uint8_t *src = srcp + y * src_linesize + x; - int j; - for (j = 0; j < block_size; j++) { + for (int j = 0; j < block_size; j++) dst[j] = src[j]; - } } static void get_block_row16(const uint8_t *srcp, int src_linesize, int y, int x, int block_size, float *dst) { const uint16_t *src = (uint16_t *)srcp + y * src_linesize / 2 + x; - int j; - for (j = 0; j < block_size; j++) { + for (int j = 0; j < block_size; j++) dst[j] = src[j]; - } } static void basic_block_filtering(BM3DContext *s, const uint8_t *src, int src_linesize, @@ -387,7 +387,8 @@ static void basic_block_filtering(BM3DContext *s, const uint8_t *src, int src_li int y, int x, int plane, int jobnr) { SliceContext *sc = &s->slices[jobnr]; - const int buffer_linesize = s->block_size * s->block_size; + const int pblock_size = s->pblock_size; + const int buffer_linesize = s->pblock_size * s->pblock_size; const int nb_match_blocks = sc->nb_match_blocks; const int block_size = s->block_size; const int width = s->planewidth[plane]; @@ -395,54 +396,50 @@ static void basic_block_filtering(BM3DContext *s, const uint8_t *src, int src_li const int group_size = s->group_size; float *buffer = sc->buffer; float *bufferh = sc->bufferh; + float *buffert = sc->buffert; float *bufferv = sc->bufferv; float *bufferz = sc->bufferz; float threshold[4]; float den_weight, num_weight; int retained = 0; - int i, j, k; - for (k = 0; k < nb_match_blocks; k++) { + for (int k = 0; k < nb_match_blocks; k++) { const int y = sc->match_blocks[k].y; const int x = sc->match_blocks[k].x; - for (i = 0; i < block_size; i++) { - s->get_block_row(src, src_linesize, y + i, x, block_size, bufferh + block_size * i); - av_dct_calc(sc->dctf, bufferh + block_size * i); + for (int i = 0; i < block_size; i++) { + s->get_block_row(src, src_linesize, y + i, x, block_size, bufferh + pblock_size * i); + sc->tx_fn(sc->dctf, buffert, bufferh + pblock_size * i, sizeof(float)); + for (int j = 0; j < block_size; j++) + bufferv[j * pblock_size + i] = buffert[j]; } - for (i = 0; i < block_size; i++) { - for (j = 0; j < block_size; j++) { - bufferv[i * block_size + j] = bufferh[j * block_size + i]; - } - av_dct_calc(sc->dctf, bufferv + i * block_size); - } - - for (i = 0; i < block_size; i++) { - memcpy(buffer + k * buffer_linesize + i * block_size, - bufferv + i * block_size, block_size * 4); + for (int i = 0; i < block_size; i++) { + sc->tx_fn(sc->dctf, buffert, bufferv + i * pblock_size, sizeof(float)); + memcpy(buffer + k * buffer_linesize + i * pblock_size, + buffert, block_size * sizeof(float)); } } - for (i = 0; i < block_size; i++) { - for (j = 0; j < block_size; j++) { - for (k = 0; k < nb_match_blocks; k++) - bufferz[k] = buffer[buffer_linesize * k + i * block_size + j]; + for (int i = 0; i < block_size; i++) { + for (int j = 0; j < block_size; j++) { + for (int k = 0; k < nb_match_blocks; k++) + bufferz[k] = buffer[buffer_linesize * k + i * pblock_size + j]; if (group_size > 1) - av_dct_calc(sc->gdctf, bufferz); + sc->tx_fn_g(sc->gdctf, bufferz, bufferz, sizeof(float)); bufferz += pgroup_size; } } - threshold[0] = s->hard_threshold * s->sigma * M_SQRT2 * block_size * block_size * (1 << (s->depth - 8)) / 255.f; + threshold[0] = s->hard_threshold * s->sigma * M_SQRT2 * 4.f * block_size * block_size * (1 << (s->depth - 8)) / 255.f; threshold[1] = threshold[0] * sqrtf(2.f); threshold[2] = threshold[0] * 2.f; threshold[3] = threshold[0] * sqrtf(8.f); bufferz = sc->bufferz; - for (i = 0; i < block_size; i++) { - for (j = 0; j < block_size; j++) { - for (k = 0; k < nb_match_blocks; k++) { + for (int i = 0; i < block_size; i++) { + for (int j = 0; j < block_size; j++) { + for (int k = 0; k < nb_match_blocks; k++) { const float thresh = threshold[(j == 0) + (i == 0) + (k == 0)]; if (bufferz[k] > thresh || bufferz[k] < -thresh) { @@ -457,13 +454,12 @@ static void basic_block_filtering(BM3DContext *s, const uint8_t *src, int src_li bufferz = sc->bufferz; buffer = sc->buffer; - for (i = 0; i < block_size; i++) { - for (j = 0; j < block_size; j++) { + for (int i = 0; i < block_size; i++) { + for (int j = 0; j < block_size; j++) { if (group_size > 1) - av_dct_calc(sc->gdcti, bufferz); - for (k = 0; k < nb_match_blocks; k++) { - buffer[buffer_linesize * k + i * block_size + j] = bufferz[k]; - } + sc->itx_fn_g(sc->gdcti, bufferz, bufferz, sizeof(float)); + for (int k = 0; k < nb_match_blocks; k++) + buffer[buffer_linesize * k + i * pblock_size + j] = bufferz[k]; bufferz += pgroup_size; } } @@ -472,27 +468,26 @@ static void basic_block_filtering(BM3DContext *s, const uint8_t *src, int src_li num_weight = den_weight; buffer = sc->buffer; - for (k = 0; k < nb_match_blocks; k++) { + for (int k = 0; k < nb_match_blocks; k++) { float *num = sc->num + y * width + x; float *den = sc->den + y * width + x; - for (i = 0; i < block_size; i++) { - memcpy(bufferv + i * block_size, - buffer + k * buffer_linesize + i * block_size, - block_size * 4); + for (int i = 0; i < block_size; i++) { + memcpy(bufferv + i * pblock_size, + buffer + k * buffer_linesize + i * pblock_size, + block_size * sizeof(float)); } - for (i = 0; i < block_size; i++) { - av_dct_calc(sc->dcti, bufferv + block_size * i); - for (j = 0; j < block_size; j++) { - bufferh[j * block_size + i] = bufferv[i * block_size + j]; - } + for (int i = 0; i < block_size; i++) { + sc->itx_fn(sc->dcti, buffert, bufferv + i * pblock_size, sizeof(float)); + for (int j = 0; j < block_size; j++) + bufferh[j * pblock_size + i] = buffert[j]; } - for (i = 0; i < block_size; i++) { - av_dct_calc(sc->dcti, bufferh + block_size * i); - for (j = 0; j < block_size; j++) { - num[j] += bufferh[i * block_size + j] * num_weight; + for (int i = 0; i < block_size; i++) { + sc->itx_fn(sc->dcti, buffert, bufferh + pblock_size * i, sizeof(float)); + for (int j = 0; j < block_size; j++) { + num[j] += buffert[j] * num_weight; den[j] += den_weight; } num += width; @@ -506,7 +501,8 @@ static void final_block_filtering(BM3DContext *s, const uint8_t *src, int src_li int y, int x, int plane, int jobnr) { SliceContext *sc = &s->slices[jobnr]; - const int buffer_linesize = s->block_size * s->block_size; + const int pblock_size = s->pblock_size; + const int buffer_linesize = s->pblock_size * s->pblock_size; const int nb_match_blocks = sc->nb_match_blocks; const int block_size = s->block_size; const int width = s->planewidth[plane]; @@ -523,45 +519,44 @@ static void final_block_filtering(BM3DContext *s, const uint8_t *src, int src_li float *rbufferz = sc->rbufferz; float den_weight, num_weight; float l2_wiener = 0; - int i, j, k; - for (k = 0; k < nb_match_blocks; k++) { + for (int k = 0; k < nb_match_blocks; k++) { const int y = sc->match_blocks[k].y; const int x = sc->match_blocks[k].x; - for (i = 0; i < block_size; i++) { - s->get_block_row(src, src_linesize, y + i, x, block_size, bufferh + block_size * i); - s->get_block_row(ref, ref_linesize, y + i, x, block_size, rbufferh + block_size * i); - av_dct_calc(sc->dctf, bufferh + block_size * i); - av_dct_calc(sc->dctf, rbufferh + block_size * i); + for (int i = 0; i < block_size; i++) { + s->get_block_row(src, src_linesize, y + i, x, block_size, bufferh + pblock_size * i); + s->get_block_row(ref, ref_linesize, y + i, x, block_size, rbufferh + pblock_size * i); + sc->tx_fn(sc->dctf, bufferh + pblock_size * i, bufferh + pblock_size * i, sizeof(float)); + sc->tx_fn(sc->dctf, rbufferh + pblock_size * i, rbufferh + pblock_size * i, sizeof(float)); } - for (i = 0; i < block_size; i++) { - for (j = 0; j < block_size; j++) { - bufferv[i * block_size + j] = bufferh[j * block_size + i]; - rbufferv[i * block_size + j] = rbufferh[j * block_size + i]; + for (int i = 0; i < block_size; i++) { + for (int j = 0; j < block_size; j++) { + bufferv[i * pblock_size + j] = bufferh[j * pblock_size + i]; + rbufferv[i * pblock_size + j] = rbufferh[j * pblock_size + i]; } - av_dct_calc(sc->dctf, bufferv + i * block_size); - av_dct_calc(sc->dctf, rbufferv + i * block_size); + sc->tx_fn(sc->dctf, bufferv + i * pblock_size, bufferv + i * pblock_size, sizeof(float)); + sc->tx_fn(sc->dctf, rbufferv + i * pblock_size, rbufferv + i * pblock_size, sizeof(float)); } - for (i = 0; i < block_size; i++) { - memcpy(buffer + k * buffer_linesize + i * block_size, - bufferv + i * block_size, block_size * 4); - memcpy(rbuffer + k * buffer_linesize + i * block_size, - rbufferv + i * block_size, block_size * 4); + for (int i = 0; i < block_size; i++) { + memcpy(buffer + k * buffer_linesize + i * pblock_size, + bufferv + i * pblock_size, block_size * sizeof(float)); + memcpy(rbuffer + k * buffer_linesize + i * pblock_size, + rbufferv + i * pblock_size, block_size * sizeof(float)); } } - for (i = 0; i < block_size; i++) { - for (j = 0; j < block_size; j++) { - for (k = 0; k < nb_match_blocks; k++) { - bufferz[k] = buffer[buffer_linesize * k + i * block_size + j]; - rbufferz[k] = rbuffer[buffer_linesize * k + i * block_size + j]; + for (int i = 0; i < block_size; i++) { + for (int j = 0; j < block_size; j++) { + for (int k = 0; k < nb_match_blocks; k++) { + bufferz[k] = buffer[buffer_linesize * k + i * pblock_size + j]; + rbufferz[k] = rbuffer[buffer_linesize * k + i * pblock_size + j]; } if (group_size > 1) { - av_dct_calc(sc->gdctf, bufferz); - av_dct_calc(sc->gdctf, rbufferz); + sc->tx_fn_g(sc->gdctf, bufferz, bufferz, sizeof(float)); + sc->tx_fn_g(sc->gdctf, rbufferz, rbufferz, sizeof(float)); } bufferz += pgroup_size; rbufferz += pgroup_size; @@ -571,9 +566,9 @@ static void final_block_filtering(BM3DContext *s, const uint8_t *src, int src_li bufferz = sc->bufferz; rbufferz = sc->rbufferz; - for (i = 0; i < block_size; i++) { - for (j = 0; j < block_size; j++) { - for (k = 0; k < nb_match_blocks; k++) { + for (int i = 0; i < block_size; i++) { + for (int j = 0; j < block_size; j++) { + for (int k = 0; k < nb_match_blocks; k++) { const float ref_sqr = rbufferz[k] * rbufferz[k]; float wiener_coef = ref_sqr / (ref_sqr + sigma_sqr); @@ -589,12 +584,12 @@ static void final_block_filtering(BM3DContext *s, const uint8_t *src, int src_li bufferz = sc->bufferz; buffer = sc->buffer; - for (i = 0; i < block_size; i++) { - for (j = 0; j < block_size; j++) { + for (int i = 0; i < block_size; i++) { + for (int j = 0; j < block_size; j++) { if (group_size > 1) - av_dct_calc(sc->gdcti, bufferz); - for (k = 0; k < nb_match_blocks; k++) { - buffer[buffer_linesize * k + i * block_size + j] = bufferz[k]; + sc->itx_fn_g(sc->gdcti, bufferz, bufferz, sizeof(float)); + for (int k = 0; k < nb_match_blocks; k++) { + buffer[buffer_linesize * k + i * pblock_size + j] = bufferz[k]; } bufferz += pgroup_size; } @@ -604,27 +599,27 @@ static void final_block_filtering(BM3DContext *s, const uint8_t *src, int src_li den_weight = 1.f / l2_wiener; num_weight = den_weight; - for (k = 0; k < nb_match_blocks; k++) { + for (int k = 0; k < nb_match_blocks; k++) { float *num = sc->num + y * width + x; float *den = sc->den + y * width + x; - for (i = 0; i < block_size; i++) { - memcpy(bufferv + i * block_size, - buffer + k * buffer_linesize + i * block_size, - block_size * 4); + for (int i = 0; i < block_size; i++) { + memcpy(bufferv + i * pblock_size, + buffer + k * buffer_linesize + i * pblock_size, + block_size * sizeof(float)); } - for (i = 0; i < block_size; i++) { - av_dct_calc(sc->dcti, bufferv + block_size * i); - for (j = 0; j < block_size; j++) { - bufferh[j * block_size + i] = bufferv[i * block_size + j]; + for (int i = 0; i < block_size; i++) { + sc->itx_fn(sc->dcti, bufferv + pblock_size * i, bufferv + pblock_size * i, sizeof(float)); + for (int j = 0; j < block_size; j++) { + bufferh[j * pblock_size + i] = bufferv[i * pblock_size + j]; } } - for (i = 0; i < block_size; i++) { - av_dct_calc(sc->dcti, bufferh + block_size * i); - for (j = 0; j < block_size; j++) { - num[j] += bufferh[i * block_size + j] * num_weight; + for (int i = 0; i < block_size; i++) { + sc->itx_fn(sc->dcti, bufferh + pblock_size * i, bufferh + pblock_size * i, sizeof(float)); + for (int j = 0; j < block_size; j++) { + num[j] += bufferh[i * pblock_size + j] * num_weight; den[j] += den_weight; } num += width; @@ -638,15 +633,14 @@ static void do_output(BM3DContext *s, uint8_t *dst, int dst_linesize, { const int height = s->planeheight[plane]; const int width = s->planewidth[plane]; - int i, j, k; - for (i = 0; i < height; i++) { - for (j = 0; j < width; j++) { + for (int i = 0; i < height; i++) { + for (int j = 0; j < width; j++) { uint8_t *dstp = dst + i * dst_linesize; float sum_den = 0.f; float sum_num = 0.f; - for (k = 0; k < nb_jobs; k++) { + for (int k = 0; k < nb_jobs; k++) { SliceContext *sc = &s->slices[k]; float num = sc->num[i * width + j]; float den = sc->den[i * width + j]; @@ -666,15 +660,14 @@ static void do_output16(BM3DContext *s, uint8_t *dst, int dst_linesize, const int height = s->planeheight[plane]; const int width = s->planewidth[plane]; const int depth = s->depth; - int i, j, k; - for (i = 0; i < height; i++) { - for (j = 0; j < width; j++) { + for (int i = 0; i < height; i++) { + for (int j = 0; j < width; j++) { uint16_t *dstp = (uint16_t *)dst + i * dst_linesize / 2; float sum_den = 0.f; float sum_num = 0.f; - for (k = 0; k < nb_jobs; k++) { + for (int k = 0; k < nb_jobs; k++) { SliceContext *sc = &s->slices[k]; float num = sc->num[i * width + j]; float den = sc->den[i * width + j]; @@ -706,17 +699,16 @@ static int filter_slice(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs) const int slice_start = (((height + block_step - 1) / block_step) * jobnr / nb_jobs) * block_step; const int slice_end = (jobnr == nb_jobs - 1) ? block_pos_bottom + block_step : (((height + block_step - 1) / block_step) * (jobnr + 1) / nb_jobs) * block_step; - int i, j; - memset(sc->num, 0, width * height * sizeof(FFTSample)); - memset(sc->den, 0, width * height * sizeof(FFTSample)); + memset(sc->num, 0, width * height * sizeof(float)); + memset(sc->den, 0, width * height * sizeof(float)); - for (j = slice_start; j < slice_end; j += block_step) { + for (int j = slice_start; j < slice_end; j += block_step) { if (j > block_pos_bottom) { j = block_pos_bottom; } - for (i = 0; i < block_pos_right + block_step; i += block_step) { + for (int i = 0; i < block_pos_right + block_step; i += block_step) { if (i > block_pos_right) { i = block_pos_right; } @@ -749,7 +741,7 @@ static int filter_frame(AVFilterContext *ctx, AVFrame **out, AVFrame *in, AVFram if (!((1 << p) & s->planes) || ctx->is_disabled) { av_image_copy_plane((*out)->data[p], (*out)->linesize[p], in->data[p], in->linesize[p], - s->planewidth[p], s->planeheight[p]); + s->planewidth[p] * (1 + (s->depth > 8)), s->planeheight[p]); continue; } @@ -773,7 +765,6 @@ static int config_input(AVFilterLink *inlink) const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(inlink->format); AVFilterContext *ctx = inlink->dst; BM3DContext *s = ctx->priv; - int i, group_bits; s->nb_threads = FFMIN(ff_filter_get_nb_threads(ctx), MAX_NB_THREADS); s->nb_planes = av_pix_fmt_count_planes(inlink->format); @@ -783,43 +774,53 @@ static int config_input(AVFilterLink *inlink) s->planeheight[0] = s->planeheight[3] = inlink->h; s->planewidth[1] = s->planewidth[2] = AV_CEIL_RSHIFT(inlink->w, desc->log2_chroma_w); s->planewidth[0] = s->planewidth[3] = inlink->w; + s->pblock_size = FFALIGN(s->block_size * 2, av_cpu_max_align()); + s->pgroup_size = FFALIGN(s->group_size * 2, av_cpu_max_align()); - for (group_bits = 4; 1 << group_bits < s->group_size; group_bits++); - s->group_bits = group_bits; - s->pgroup_size = 1 << group_bits; - - for (i = 0; i < s->nb_threads; i++) { + for (int i = 0; i < s->nb_threads; i++) { SliceContext *sc = &s->slices[i]; + float iscale = 0.5f / s->block_size; + float scale = 1.f; + int ret; - sc->num = av_calloc(FFALIGN(s->planewidth[0], s->block_size) * FFALIGN(s->planeheight[0], s->block_size), sizeof(FFTSample)); - sc->den = av_calloc(FFALIGN(s->planewidth[0], s->block_size) * FFALIGN(s->planeheight[0], s->block_size), sizeof(FFTSample)); + sc->num = av_calloc(FFALIGN(s->planewidth[0], s->block_size) * FFALIGN(s->planeheight[0], s->block_size), sizeof(float)); + sc->den = av_calloc(FFALIGN(s->planewidth[0], s->block_size) * FFALIGN(s->planeheight[0], s->block_size), sizeof(float)); if (!sc->num || !sc->den) return AVERROR(ENOMEM); - sc->dctf = av_dct_init(av_log2(s->block_size), DCT_II); - sc->dcti = av_dct_init(av_log2(s->block_size), DCT_III); - if (!sc->dctf || !sc->dcti) - return AVERROR(ENOMEM); + ret = av_tx_init(&sc->dctf, &sc->tx_fn, AV_TX_FLOAT_DCT, 0, s->block_size >> 0, &scale, 0); + if (ret < 0) + return ret; - if (s->group_bits > 1) { - sc->gdctf = av_dct_init(s->group_bits, DCT_II); - sc->gdcti = av_dct_init(s->group_bits, DCT_III); - if (!sc->gdctf || !sc->gdcti) - return AVERROR(ENOMEM); + ret = av_tx_init(&sc->dcti, &sc->itx_fn, AV_TX_FLOAT_DCT, 1, s->block_size >> 1, &iscale, 0); + if (ret < 0) + return ret; + + if (s->group_size > 1) { + float iscale = 0.5f / s->group_size; + + ret = av_tx_init(&sc->gdctf, &sc->tx_fn_g, AV_TX_FLOAT_DCT, 0, s->group_size >> 0, &scale, 0); + if (ret < 0) + return ret; + + ret = av_tx_init(&sc->gdcti, &sc->itx_fn_g, AV_TX_FLOAT_DCT, 1, s->group_size >> 1, &iscale, 0); + if (ret < 0) + return ret; } - sc->buffer = av_calloc(s->block_size * s->block_size * s->pgroup_size, sizeof(*sc->buffer)); - sc->bufferz = av_calloc(s->block_size * s->block_size * s->pgroup_size, sizeof(*sc->bufferz)); - sc->bufferh = av_calloc(s->block_size * s->block_size, sizeof(*sc->bufferh)); - sc->bufferv = av_calloc(s->block_size * s->block_size, sizeof(*sc->bufferv)); - if (!sc->bufferh || !sc->bufferv || !sc->buffer || !sc->bufferz) + sc->buffer = av_calloc(s->pblock_size * s->pblock_size * s->pgroup_size, sizeof(*sc->buffer)); + sc->bufferz = av_calloc(s->pblock_size * s->pblock_size * s->pgroup_size, sizeof(*sc->bufferz)); + sc->bufferh = av_calloc(s->pblock_size * s->pblock_size, sizeof(*sc->bufferh)); + sc->bufferv = av_calloc(s->pblock_size * s->pblock_size, sizeof(*sc->bufferv)); + sc->buffert = av_calloc(s->pblock_size, sizeof(*sc->buffert)); + if (!sc->bufferh || !sc->bufferv || !sc->buffer || !sc->bufferz || !sc->buffert) return AVERROR(ENOMEM); if (s->mode == FINAL) { - sc->rbuffer = av_calloc(s->block_size * s->block_size * s->pgroup_size, sizeof(*sc->rbuffer)); - sc->rbufferz = av_calloc(s->block_size * s->block_size * s->pgroup_size, sizeof(*sc->rbufferz)); - sc->rbufferh = av_calloc(s->block_size * s->block_size, sizeof(*sc->rbufferh)); - sc->rbufferv = av_calloc(s->block_size * s->block_size, sizeof(*sc->rbufferv)); + sc->rbuffer = av_calloc(s->pblock_size * s->pblock_size * s->pgroup_size, sizeof(*sc->rbuffer)); + sc->rbufferz = av_calloc(s->pblock_size * s->pblock_size * s->pgroup_size, sizeof(*sc->rbufferz)); + sc->rbufferh = av_calloc(s->pblock_size * s->pblock_size, sizeof(*sc->rbufferh)); + sc->rbufferv = av_calloc(s->pblock_size * s->pblock_size, sizeof(*sc->rbufferv)); if (!sc->rbufferh || !sc->rbufferv || !sc->rbuffer || !sc->rbufferz) return AVERROR(ENOMEM); } @@ -919,13 +920,12 @@ static av_cold int init(AVFilterContext *ctx) return AVERROR_BUG; } - s->block_size = 1 << s->block_size; - if (s->block_step > s->block_size) { av_log(ctx, AV_LOG_WARNING, "bstep: %d can't be bigger than block size. Changing to %d.\n", s->block_step, s->block_size); s->block_step = s->block_size; } + if (s->bm_step > s->bm_range) { av_log(ctx, AV_LOG_WARNING, "mstep: %d can't be bigger than block matching range. Changing to %d.\n", s->bm_step, s->bm_range); @@ -1004,24 +1004,24 @@ static int config_output(AVFilterLink *outlink) static av_cold void uninit(AVFilterContext *ctx) { BM3DContext *s = ctx->priv; - int i; if (s->ref) ff_framesync_uninit(&s->fs); - for (i = 0; i < s->nb_threads; i++) { + for (int i = 0; i < s->nb_threads; i++) { SliceContext *sc = &s->slices[i]; av_freep(&sc->num); av_freep(&sc->den); - av_dct_end(sc->gdctf); - av_dct_end(sc->gdcti); - av_dct_end(sc->dctf); - av_dct_end(sc->dcti); + av_tx_uninit(&sc->gdctf); + av_tx_uninit(&sc->gdcti); + av_tx_uninit(&sc->dctf); + av_tx_uninit(&sc->dcti); av_freep(&sc->buffer); av_freep(&sc->bufferh); + av_freep(&sc->buffert); av_freep(&sc->bufferv); av_freep(&sc->bufferz); av_freep(&sc->rbuffer);