From fc26dca64e0e5d20bb0fcc8743d073cf5b107264 Mon Sep 17 00:00:00 2001 From: "Guo, Yejun" Date: Tue, 16 Mar 2021 13:02:56 +0800 Subject: [PATCH] lavfi/dnn: add classify support with openvino backend Signed-off-by: Guo, Yejun --- libavfilter/dnn/dnn_backend_openvino.c | 145 +++++++++++++++++++++---- libavfilter/dnn/dnn_io_proc.c | 60 ++++++++++ libavfilter/dnn/dnn_io_proc.h | 1 + libavfilter/dnn_filter_common.c | 21 ++++ libavfilter/dnn_filter_common.h | 2 + libavfilter/dnn_interface.h | 10 +- 6 files changed, 219 insertions(+), 20 deletions(-) diff --git a/libavfilter/dnn/dnn_backend_openvino.c b/libavfilter/dnn/dnn_backend_openvino.c index 4e58ff6d9c..1ff8a720b9 100644 --- a/libavfilter/dnn/dnn_backend_openvino.c +++ b/libavfilter/dnn/dnn_backend_openvino.c @@ -29,6 +29,7 @@ #include "libavutil/avassert.h" #include "libavutil/opt.h" #include "libavutil/avstring.h" +#include "libavutil/detection_bbox.h" #include "../internal.h" #include "queue.h" #include "safe_queue.h" @@ -74,6 +75,7 @@ typedef struct TaskItem { // one task might have multiple inferences typedef struct InferenceItem { TaskItem *task; + uint32_t bbox_index; } InferenceItem; // one request for one call to openvino @@ -182,12 +184,23 @@ static DNNReturnType fill_model_input_ov(OVModel *ov_model, RequestItem *request request->inferences[i] = inference; request->inference_count = i + 1; task = inference->task; - if (task->do_ioproc) { - if (ov_model->model->frame_pre_proc != NULL) { - ov_model->model->frame_pre_proc(task->in_frame, &input, ov_model->model->filter_ctx); - } else { - ff_proc_from_frame_to_dnn(task->in_frame, &input, ov_model->model->func_type, ctx); + switch (task->ov_model->model->func_type) { + case DFT_PROCESS_FRAME: + case DFT_ANALYTICS_DETECT: + if (task->do_ioproc) { + if (ov_model->model->frame_pre_proc != NULL) { + ov_model->model->frame_pre_proc(task->in_frame, &input, ov_model->model->filter_ctx); + } else { + ff_proc_from_frame_to_dnn(task->in_frame, &input, ov_model->model->func_type, ctx); + } } + break; + case DFT_ANALYTICS_CLASSIFY: + ff_frame_to_dnn_classify(task->in_frame, &input, inference->bbox_index, ctx); + break; + default: + av_assert0(!"should not reach here"); + break; } input.data = (uint8_t *)input.data + input.width * input.height * input.channels * get_datatype_size(input.dt); @@ -276,6 +289,13 @@ static void infer_completion_callback(void *args) } task->ov_model->model->detect_post_proc(task->out_frame, &output, 1, task->ov_model->model->filter_ctx); break; + case DFT_ANALYTICS_CLASSIFY: + if (!task->ov_model->model->classify_post_proc) { + av_log(ctx, AV_LOG_ERROR, "classify filter needs to provide post proc\n"); + return; + } + task->ov_model->model->classify_post_proc(task->out_frame, &output, request->inferences[i]->bbox_index, task->ov_model->model->filter_ctx); + break; default: av_assert0(!"should not reach here"); break; @@ -513,7 +533,44 @@ static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input return DNN_ERROR; } -static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, TaskItem *task, Queue *inference_queue) +static int contain_valid_detection_bbox(AVFrame *frame) +{ + AVFrameSideData *sd; + const AVDetectionBBoxHeader *header; + const AVDetectionBBox *bbox; + + sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES); + if (!sd) { // this frame has nothing detected + return 0; + } + + if (!sd->size) { + return 0; + } + + header = (const AVDetectionBBoxHeader *)sd->data; + if (!header->nb_bboxes) { + return 0; + } + + for (uint32_t i = 0; i < header->nb_bboxes; i++) { + bbox = av_get_detection_bbox(header, i); + if (bbox->x < 0 || bbox->w < 0 || bbox->x + bbox->w >= frame->width) { + return 0; + } + if (bbox->y < 0 || bbox->h < 0 || bbox->y + bbox->h >= frame->width) { + return 0; + } + + if (bbox->classify_count == AV_NUM_DETECTION_BBOX_CLASSIFY) { + return 0; + } + } + + return 1; +} + +static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, TaskItem *task, Queue *inference_queue, DNNExecBaseParams *exec_params) { switch (func_type) { case DFT_PROCESS_FRAME: @@ -532,6 +589,45 @@ static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, Task } return DNN_SUCCESS; } + case DFT_ANALYTICS_CLASSIFY: + { + const AVDetectionBBoxHeader *header; + AVFrame *frame = task->in_frame; + AVFrameSideData *sd; + DNNExecClassificationParams *params = (DNNExecClassificationParams *)exec_params; + + task->inference_todo = 0; + task->inference_done = 0; + + if (!contain_valid_detection_bbox(frame)) { + return DNN_SUCCESS; + } + + sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES); + header = (const AVDetectionBBoxHeader *)sd->data; + + for (uint32_t i = 0; i < header->nb_bboxes; i++) { + InferenceItem *inference; + const AVDetectionBBox *bbox = av_get_detection_bbox(header, i); + + if (av_strncasecmp(bbox->detect_label, params->target, sizeof(bbox->detect_label)) != 0) { + continue; + } + + inference = av_malloc(sizeof(*inference)); + if (!inference) { + return DNN_ERROR; + } + task->inference_todo++; + inference->task = task; + inference->bbox_index = i; + if (ff_queue_push_back(inference_queue, inference) < 0) { + av_freep(&inference); + return DNN_ERROR; + } + } + return DNN_SUCCESS; + } default: av_assert0(!"should not reach here"); return DNN_ERROR; @@ -598,7 +694,7 @@ static DNNReturnType get_output_ov(void *model, const char *input_name, int inpu task.out_frame = out_frame; task.ov_model = ov_model; - if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue) != DNN_SUCCESS) { + if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue, NULL) != DNN_SUCCESS) { av_frame_free(&out_frame); av_frame_free(&in_frame); av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n"); @@ -690,6 +786,14 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNExecBaseParams * return DNN_ERROR; } + if (model->func_type == DFT_ANALYTICS_CLASSIFY) { + // Once we add async support for tensorflow backend and native backend, + // we'll combine the two sync/async functions in dnn_interface.h to + // simplify the code in filter, and async will be an option within backends. + // so, do not support now, and classify filter will not call this function. + return DNN_ERROR; + } + if (ctx->options.batch_size > 1) { avpriv_report_missing_feature(ctx, "batch mode for sync execution"); return DNN_ERROR; @@ -710,7 +814,7 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNExecBaseParams * task.out_frame = exec_params->out_frame ? exec_params->out_frame : exec_params->in_frame; task.ov_model = ov_model; - if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue) != DNN_SUCCESS) { + if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue, exec_params) != DNN_SUCCESS) { av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n"); return DNN_ERROR; } @@ -730,6 +834,7 @@ DNNReturnType ff_dnn_execute_model_async_ov(const DNNModel *model, DNNExecBasePa OVContext *ctx = &ov_model->ctx; RequestItem *request; TaskItem *task; + DNNReturnType ret; if (ff_check_exec_params(ctx, DNN_OV, model->func_type, exec_params) != 0) { return DNN_ERROR; @@ -761,23 +866,25 @@ DNNReturnType ff_dnn_execute_model_async_ov(const DNNModel *model, DNNExecBasePa return DNN_ERROR; } - if (extract_inference_from_task(ov_model->model->func_type, task, ov_model->inference_queue) != DNN_SUCCESS) { + if (extract_inference_from_task(model->func_type, task, ov_model->inference_queue, exec_params) != DNN_SUCCESS) { av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n"); return DNN_ERROR; } - if (ff_queue_size(ov_model->inference_queue) < ctx->options.batch_size) { - // not enough inference items queued for a batch - return DNN_SUCCESS; + while (ff_queue_size(ov_model->inference_queue) >= ctx->options.batch_size) { + request = ff_safe_queue_pop_front(ov_model->request_queue); + if (!request) { + av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n"); + return DNN_ERROR; + } + + ret = execute_model_ov(request, ov_model->inference_queue); + if (ret != DNN_SUCCESS) { + return ret; + } } - request = ff_safe_queue_pop_front(ov_model->request_queue); - if (!request) { - av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n"); - return DNN_ERROR; - } - - return execute_model_ov(request, ov_model->inference_queue); + return DNN_SUCCESS; } DNNAsyncStatusType ff_dnn_get_async_result_ov(const DNNModel *model, AVFrame **in, AVFrame **out) diff --git a/libavfilter/dnn/dnn_io_proc.c b/libavfilter/dnn/dnn_io_proc.c index e104cc5064..5f60d68078 100644 --- a/libavfilter/dnn/dnn_io_proc.c +++ b/libavfilter/dnn/dnn_io_proc.c @@ -22,6 +22,7 @@ #include "libavutil/imgutils.h" #include "libswscale/swscale.h" #include "libavutil/avassert.h" +#include "libavutil/detection_bbox.h" DNNReturnType ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx) { @@ -175,6 +176,65 @@ static enum AVPixelFormat get_pixel_format(DNNData *data) return AV_PIX_FMT_BGR24; } +DNNReturnType ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index, void *log_ctx) +{ + const AVPixFmtDescriptor *desc; + int offsetx[4], offsety[4]; + uint8_t *bbox_data[4]; + struct SwsContext *sws_ctx; + int linesizes[4]; + enum AVPixelFormat fmt; + int left, top, width, height; + const AVDetectionBBoxHeader *header; + const AVDetectionBBox *bbox; + AVFrameSideData *sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES); + av_assert0(sd); + + header = (const AVDetectionBBoxHeader *)sd->data; + bbox = av_get_detection_bbox(header, bbox_index); + + left = bbox->x; + width = bbox->w; + top = bbox->y; + height = bbox->h; + + fmt = get_pixel_format(input); + sws_ctx = sws_getContext(width, height, frame->format, + input->width, input->height, fmt, + SWS_FAST_BILINEAR, NULL, NULL, NULL); + if (!sws_ctx) { + av_log(log_ctx, AV_LOG_ERROR, "Failed to create scale context for the conversion " + "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n", + av_get_pix_fmt_name(frame->format), width, height, + av_get_pix_fmt_name(fmt), input->width, input->height); + return DNN_ERROR; + } + + if (av_image_fill_linesizes(linesizes, fmt, input->width) < 0) { + av_log(log_ctx, AV_LOG_ERROR, "unable to get linesizes with av_image_fill_linesizes"); + sws_freeContext(sws_ctx); + return DNN_ERROR; + } + + desc = av_pix_fmt_desc_get(frame->format); + offsetx[1] = offsetx[2] = AV_CEIL_RSHIFT(left, desc->log2_chroma_w); + offsetx[0] = offsetx[3] = left; + + offsety[1] = offsety[2] = AV_CEIL_RSHIFT(top, desc->log2_chroma_h); + offsety[0] = offsety[3] = top; + + for (int k = 0; frame->data[k]; k++) + bbox_data[k] = frame->data[k] + offsety[k] * frame->linesize[k] + offsetx[k]; + + sws_scale(sws_ctx, (const uint8_t *const *)&bbox_data, frame->linesize, + 0, height, + (uint8_t *const *)(&input->data), linesizes); + + sws_freeContext(sws_ctx); + + return DNN_SUCCESS; +} + static DNNReturnType proc_from_frame_to_dnn_analytics(AVFrame *frame, DNNData *input, void *log_ctx) { struct SwsContext *sws_ctx; diff --git a/libavfilter/dnn/dnn_io_proc.h b/libavfilter/dnn/dnn_io_proc.h index 91ad3cb261..16dcdd6d1a 100644 --- a/libavfilter/dnn/dnn_io_proc.h +++ b/libavfilter/dnn/dnn_io_proc.h @@ -32,5 +32,6 @@ DNNReturnType ff_proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, DNNFunctionType func_type, void *log_ctx); DNNReturnType ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx); +DNNReturnType ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index, void *log_ctx); #endif diff --git a/libavfilter/dnn_filter_common.c b/libavfilter/dnn_filter_common.c index c085884eb4..52c7a5392a 100644 --- a/libavfilter/dnn_filter_common.c +++ b/libavfilter/dnn_filter_common.c @@ -77,6 +77,12 @@ int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc) return 0; } +int ff_dnn_set_classify_post_proc(DnnContext *ctx, ClassifyPostProc post_proc) +{ + ctx->model->classify_post_proc = post_proc; + return 0; +} + DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input) { return ctx->model->get_input(ctx->model->model, input, ctx->model_inputname); @@ -112,6 +118,21 @@ DNNReturnType ff_dnn_execute_model_async(DnnContext *ctx, AVFrame *in_frame, AVF return (ctx->dnn_module->execute_model_async)(ctx->model, &exec_params); } +DNNReturnType ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame, char *target) +{ + DNNExecClassificationParams class_params = { + { + .input_name = ctx->model_inputname, + .output_names = (const char **)&ctx->model_outputname, + .nb_output = 1, + .in_frame = in_frame, + .out_frame = out_frame, + }, + .target = target, + }; + return (ctx->dnn_module->execute_model_async)(ctx->model, &class_params.base); +} + DNNAsyncStatusType ff_dnn_get_async_result(DnnContext *ctx, AVFrame **in_frame, AVFrame **out_frame) { return (ctx->dnn_module->get_async_result)(ctx->model, in_frame, out_frame); diff --git a/libavfilter/dnn_filter_common.h b/libavfilter/dnn_filter_common.h index 8deb18b39a..e7736d2bac 100644 --- a/libavfilter/dnn_filter_common.h +++ b/libavfilter/dnn_filter_common.h @@ -50,10 +50,12 @@ typedef struct DnnContext { int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx); int ff_dnn_set_frame_proc(DnnContext *ctx, FramePrePostProc pre_proc, FramePrePostProc post_proc); int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc); +int ff_dnn_set_classify_post_proc(DnnContext *ctx, ClassifyPostProc post_proc); DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input); DNNReturnType ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height); DNNReturnType ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame); DNNReturnType ff_dnn_execute_model_async(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame); +DNNReturnType ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame, char *target); DNNAsyncStatusType ff_dnn_get_async_result(DnnContext *ctx, AVFrame **in_frame, AVFrame **out_frame); DNNReturnType ff_dnn_flush(DnnContext *ctx); void ff_dnn_uninit(DnnContext *ctx); diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h index 941670675d..799244ee14 100644 --- a/libavfilter/dnn_interface.h +++ b/libavfilter/dnn_interface.h @@ -52,7 +52,7 @@ typedef enum { DFT_NONE, DFT_PROCESS_FRAME, // process the whole frame DFT_ANALYTICS_DETECT, // detect from the whole frame - // we can add more such as detect_from_crop, classify_from_bbox, etc. + DFT_ANALYTICS_CLASSIFY, // classify for each bounding box }DNNFunctionType; typedef struct DNNData{ @@ -71,8 +71,14 @@ typedef struct DNNExecBaseParams { AVFrame *out_frame; } DNNExecBaseParams; +typedef struct DNNExecClassificationParams { + DNNExecBaseParams base; + const char *target; +} DNNExecClassificationParams; + typedef int (*FramePrePostProc)(AVFrame *frame, DNNData *model, AVFilterContext *filter_ctx); typedef int (*DetectPostProc)(AVFrame *frame, DNNData *output, uint32_t nb, AVFilterContext *filter_ctx); +typedef int (*ClassifyPostProc)(AVFrame *frame, DNNData *output, uint32_t bbox_index, AVFilterContext *filter_ctx); typedef struct DNNModel{ // Stores model that can be different for different backends. @@ -97,6 +103,8 @@ typedef struct DNNModel{ FramePrePostProc frame_post_proc; // set the post process to interpret detect result from DNNData DetectPostProc detect_post_proc; + // set the post process to interpret classify result from DNNData + ClassifyPostProc classify_post_proc; } DNNModel; // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.