diff --git a/libavfilter/dnn/dnn_backend_openvino.c b/libavfilter/dnn/dnn_backend_openvino.c
index d3af8c34ce..6fe8b9c243 100644
--- a/libavfilter/dnn/dnn_backend_openvino.c
+++ b/libavfilter/dnn/dnn_backend_openvino.c
@@ -386,9 +386,9 @@ static void infer_completion_callback(void *args)
             ov_shape_free(&output_shape);
             return;
         }
-        output.channels = dims[1];
-        output.height = dims[2];
-        output.width = dims[3];
+        output.channels = output_shape.rank > 2 ? dims[output_shape.rank - 3] : 1;
+        output.height = output_shape.rank > 1 ? dims[output_shape.rank - 2] : 1;
+        output.width = output_shape.rank > 0 ? dims[output_shape.rank - 1] : 1;
         av_assert0(request->lltask_count <= dims[0]);
         ov_shape_free(&output_shape);
 #else
diff --git a/libavfilter/vf_dnn_detect.c b/libavfilter/vf_dnn_detect.c
index 9db90ee4cf..7ac3bb0b58 100644
--- a/libavfilter/vf_dnn_detect.c
+++ b/libavfilter/vf_dnn_detect.c
@@ -30,9 +30,11 @@
 #include "libavutil/time.h"
 #include "libavutil/avstring.h"
 #include "libavutil/detection_bbox.h"
+#include "libavutil/fifo.h"
 
 typedef enum {
-    DDMT_SSD
+    DDMT_SSD,
+    DDMT_YOLOV1V2,
 } DNNDetectionModelType;
 
 typedef struct DnnDetectContext {
@@ -43,6 +45,15 @@ typedef struct DnnDetectContext {
     char **labels;
     int label_count;
     DNNDetectionModelType model_type;
+    int cell_w;
+    int cell_h;
+    int nb_classes;
+    AVFifo *bboxes_fifo;
+    int scale_width;
+    int scale_height;
+    char *anchors_str;
+    float *anchors;
+    int nb_anchor;
 } DnnDetectContext;
 
 #define OFFSET(x) offsetof(DnnDetectContext, dnnctx.x)
@@ -61,11 +72,260 @@ static const AVOption dnn_detect_options[] = {
     { "labels", "path to labels file", OFFSET2(labels_filename), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },
     { "model_type", "DNN detection model type", OFFSET2(model_type), AV_OPT_TYPE_INT, { .i64 = DDMT_SSD }, INT_MIN, INT_MAX, FLAGS, "model_type" },
     { "ssd", "output shape [1, 1, N, 7]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_SSD }, 0, 0, FLAGS, "model_type" },
+    { "yolo", "output shape [1, N*Cx*Cy*DetectionBox]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_YOLOV1V2 }, 0, 0, FLAGS, "model_type" },
+    { "cell_w", "cell width", OFFSET2(cell_w), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INT_MAX, FLAGS },
+    { "cell_h", "cell height", OFFSET2(cell_h), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INT_MAX, FLAGS },
+    { "nb_classes", "The number of class", OFFSET2(nb_classes), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INT_MAX, FLAGS },
+    { "anchors", "anchors, splited by '&'", OFFSET2(anchors_str), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },
     { NULL }
 };
 
 AVFILTER_DEFINE_CLASS(dnn_detect);
 
+/**
+ * Pick the class with the highest score for one grid cell.
+ *
+ * Scores are strided: the score of class i is label_data[i * cell_size].
+ * Returns the winning class index, or 0 when no score is greater than 0.
+ */
+static int dnn_detect_get_label_id(int nb_classes, int cell_size, float *label_data)
+{
+    float max_prob = 0;
+    int label_id = 0;
+    for (int i = 0; i < nb_classes; i++) {
+        if (label_data[i * cell_size] > max_prob) {
+            max_prob = label_data[i * cell_size];
+            label_id = i;
+        }
+    }
+    return label_id;
+}
+
+/**
+ * Parse an '&'-separated list of floats into a newly allocated array.
+ *
+ * anchors_str is tokenized in place by av_strtok(). On success the array is
+ * stored in *anchors and the number of elements is returned. On failure 0 is
+ * returned and *anchors is left untouched.
+ */
+static int dnn_detect_parse_anchors(char *anchors_str, float **anchors)
+{
+    char *saveptr = NULL, *token;
+    float *anchors_buf;
+    int nb_anchor = 0, i = 0;
+    while (anchors_str[i] != '\0') {
+        if (anchors_str[i] == '&')
+            nb_anchor++;
+        i++;
+    }
+    nb_anchor++;
+    anchors_buf = av_mallocz(nb_anchor * sizeof(*anchors_buf));
+    if (!anchors_buf) {
+        return 0;
+    }
+    for (i = 0; i < nb_anchor; i++) {
+        token = av_strtok(anchors_str, "&", &saveptr);
+        /* empty fields ("1&&2", trailing '&') give fewer tokens than counted */
+        if (!token) {
+            av_freep(&anchors_buf);
+            return 0;
+        }
+        anchors_buf[i] = strtof(token, NULL);
+        anchors_str = NULL;
+    }
+    *anchors = anchors_buf;
+    return nb_anchor;
+}
+
+/* Calculate Intersection Over Union (overlap area / combined area) of two boxes */
+static float dnn_detect_IOU(AVDetectionBBox *bbox1, AVDetectionBBox *bbox2)
+{
+    float overlapping_width = FFMIN(bbox1->x + bbox1->w, bbox2->x + bbox2->w) - FFMAX(bbox1->x, bbox2->x);
+    float overlapping_height = FFMIN(bbox1->y + bbox1->h, bbox2->y + bbox2->h) - FFMAX(bbox1->y, bbox2->y);
+    float intersection_area =
+        (overlapping_width < 0 || overlapping_height < 0) ? 0 : overlapping_height * overlapping_width;
+    float union_area = bbox1->w * bbox1->h + bbox2->w * bbox2->h - intersection_area;
+    return intersection_area / union_area;
+}
+
+/**
+ * Decode one YOLO output tensor into candidate boxes.
+ *
+ * Every cell whose box score reaches ctx->confidence is converted to an
+ * AVDetectionBBox in frame coordinates and queued in ctx->bboxes_fifo.
+ * Returns 0 on success or a negative AVERROR code.
+ */
+static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int output_index,
+                                        AVFilterContext *filter_ctx)
+{
+    DnnDetectContext *ctx = filter_ctx->priv;
+    float conf_threshold = ctx->confidence;
+    int detection_boxes, box_size;
+    /* stay 0 (rejected below) for model types this parser does not handle */
+    int cell_w = 0, cell_h = 0, scale_w = 0, scale_h = 0;
+    int nb_classes = ctx->nb_classes;
+    float *output_data = output[output_index].data;
+    float *anchors = ctx->anchors;
+    AVDetectionBBox *bbox;
+
+    if (ctx->model_type == DDMT_YOLOV1V2) {
+        cell_w = ctx->cell_w;
+        cell_h = ctx->cell_h;
+        scale_w = cell_w;
+        scale_h = cell_h;
+    }
+    box_size = nb_classes + 5;
+
+    if (!cell_h || !cell_w) {
+        av_log(filter_ctx, AV_LOG_ERROR, "cell_w and cell_h are not set\n");
+        return AVERROR(EINVAL);
+    }
+
+    if (!nb_classes) {
+        av_log(filter_ctx, AV_LOG_ERROR, "nb_classes is not set\n");
+        return AVERROR(EINVAL);
+    }
+
+    if (!anchors) {
+        av_log(filter_ctx, AV_LOG_ERROR, "anchors is not set\n");
+        return AVERROR(EINVAL);
+    }
+
+    if (output[output_index].channels * output[output_index].width *
+        output[output_index].height % (box_size * cell_w * cell_h)) {
+        av_log(filter_ctx, AV_LOG_ERROR, "wrong cell_w, cell_h or nb_classes\n");
+        return AVERROR(EINVAL);
+    }
+    detection_boxes = output[output_index].channels *
+                      output[output_index].height *
+                      output[output_index].width / box_size / cell_w / cell_h;
+
+    /* each detection box reads its own anchor (w, h) pair below */
+    if (detection_boxes > ctx->nb_anchor / 2) {
+        av_log(filter_ctx, AV_LOG_ERROR, "%d anchors given, at least %d are needed\n",
+               ctx->nb_anchor, detection_boxes * 2);
+        return AVERROR(EINVAL);
+    }
+
+    /**
+     * find all candidate bbox
+     * yolo output can be reshaped to [B, N*D, Cx, Cy]
+     * Detection box 'D' has format [`x`, `y`, `w`, `h`, `box_score`, `class_no_1`, ...,]
+     **/
+    for (int box_id = 0; box_id < detection_boxes; box_id++) {
+        for (int cx = 0; cx < cell_w; cx++)
+            for (int cy = 0; cy < cell_h; cy++) {
+                float x, y, w, h, conf;
+                float *detection_boxes_data;
+                int label_id;
+
+                detection_boxes_data = output_data + box_id * box_size * cell_w * cell_h;
+                conf = detection_boxes_data[cy * cell_w + cx + 4 * cell_w * cell_h];
+                if (conf < conf_threshold) {
+                    continue;
+                }
+
+                x = detection_boxes_data[cy * cell_w + cx];
+                y = detection_boxes_data[cy * cell_w + cx + cell_w * cell_h];
+                w = detection_boxes_data[cy * cell_w + cx + 2 * cell_w * cell_h];
+                h = detection_boxes_data[cy * cell_w + cx + 3 * cell_w * cell_h];
+                label_id = dnn_detect_get_label_id(ctx->nb_classes, cell_w * cell_h,
+                    detection_boxes_data + cy * cell_w + cx + 5 * cell_w * cell_h);
+                conf = conf * detection_boxes_data[cy * cell_w + cx + (label_id + 5) * cell_w * cell_h];
+
+                bbox = av_mallocz(sizeof(*bbox));
+                if (!bbox)
+                    return AVERROR(ENOMEM);
+
+                bbox->w = exp(w) * anchors[box_id * 2] * frame->width / scale_w;
+                bbox->h = exp(h) * anchors[box_id * 2 + 1] * frame->height / scale_h;
+                bbox->x = (cx + x) / cell_w * frame->width - bbox->w / 2;
+                bbox->y = (cy + y) / cell_h * frame->height - bbox->h / 2;
+                bbox->detect_confidence = av_make_q((int)(conf * 10000), 10000);
+                if (ctx->labels && label_id < ctx->label_count) {
+                    av_strlcpy(bbox->detect_label, ctx->labels[label_id], sizeof(bbox->detect_label));
+                } else {
+                    snprintf(bbox->detect_label, sizeof(bbox->detect_label), "%d", label_id);
+                }
+
+                if (av_fifo_write(ctx->bboxes_fifo, &bbox, 1) < 0) {
+                    av_freep(&bbox);
+                    return AVERROR(ENOMEM);
+                }
+            }
+    }
+    return 0;
+}
+
+/**
+ * Drain ctx->bboxes_fifo into the frame's detection side data.
+ *
+ * A box is dropped when another box with the same label and a higher
+ * confidence overlaps it with an IOU of at least ctx->confidence.
+ * Returns 0 on success or AVERROR(ENOMEM) if the side data cannot be created.
+ */
+static int dnn_detect_fill_side_data(AVFrame *frame, AVFilterContext *filter_ctx)
+{
+    DnnDetectContext *ctx = filter_ctx->priv;
+    float conf_threshold = ctx->confidence;
+    AVDetectionBBox *bbox;
+    int nb_bboxes = 0;
+    AVDetectionBBoxHeader *header;
+    if (av_fifo_can_read(ctx->bboxes_fifo) == 0) {
+        av_log(filter_ctx, AV_LOG_VERBOSE, "nothing detected in this frame.\n");
+        return 0;
+    }
+
+    /* remove overlap bboxes */
+    for (int i = 0; i < av_fifo_can_read(ctx->bboxes_fifo); i++){
+        av_fifo_peek(ctx->bboxes_fifo, &bbox, 1, i);
+        for (int j = 0; j < av_fifo_can_read(ctx->bboxes_fifo); j++) {
+            AVDetectionBBox *overlap_bbox;
+            av_fifo_peek(ctx->bboxes_fifo, &overlap_bbox, 1, j);
+            if (!strcmp(bbox->detect_label, overlap_bbox->detect_label) &&
+                av_cmp_q(bbox->detect_confidence, overlap_bbox->detect_confidence) < 0 &&
+                dnn_detect_IOU(bbox, overlap_bbox) >= conf_threshold) {
+                bbox->classify_count = -1; // bad result
+                nb_bboxes++;
+                break;
+            }
+        }
+    }
+    nb_bboxes = av_fifo_can_read(ctx->bboxes_fifo) - nb_bboxes;
+    header = av_detection_bbox_create_side_data(frame, nb_bboxes);
+    if (!header) {
+        av_log(filter_ctx, AV_LOG_ERROR, "failed to create side data with %d bounding boxes\n", nb_bboxes);
+        return AVERROR(ENOMEM);
+    }
+    av_strlcpy(header->source, ctx->dnnctx.model_filename, sizeof(header->source));
+
+    while(av_fifo_can_read(ctx->bboxes_fifo)) {
+        AVDetectionBBox *candidate_bbox;
+        av_fifo_read(ctx->bboxes_fifo, &candidate_bbox, 1);
+
+        if (nb_bboxes > 0 && candidate_bbox->classify_count != -1) {
+            bbox = av_get_detection_bbox(header, header->nb_bboxes - nb_bboxes);
+            memcpy(bbox, candidate_bbox, sizeof(*bbox));
+            nb_bboxes--;
+        }
+        av_freep(&candidate_bbox);
+    }
+    return 0;
+}
+
+/* Parse the first model output as YOLO, then attach the surviving boxes to the frame. */
+static int dnn_detect_post_proc_yolo(AVFrame *frame, DNNData *output, AVFilterContext *filter_ctx)
+{
+    int ret = 0;
+    ret = dnn_detect_parse_yolo_output(frame, output, 0, filter_ctx);
+    if (ret < 0)
+        return ret;
+    ret = dnn_detect_fill_side_data(frame, filter_ctx);
+    if (ret < 0)
+        return ret;
+    return 0;
+}
+
 static int dnn_detect_post_proc_ssd(AVFrame *frame, DNNData *output, AVFilterContext *filter_ctx)
 {
     DnnDetectContext *ctx = filter_ctx->priv;
@@ -158,6 +418,11 @@ static int dnn_detect_post_proc_ov(AVFrame *frame, DNNData *output, AVFilterCont
         if (ret < 0)
             return ret;
         break;
+    case DDMT_YOLOV1V2:
+        ret = dnn_detect_post_proc_yolo(frame, output, filter_ctx);
+        if (ret < 0)
+            return ret;
+        break;
     }
 
     return 0;
@@ -356,11 +621,24 @@ static av_cold int dnn_detect_init(AVFilterContext *context)
     ret = check_output_nb(ctx, dnn_ctx->backend_type, dnn_ctx->nb_outputs);
     if (ret < 0)
         return ret;
+    ctx->bboxes_fifo = av_fifo_alloc2(1, sizeof(AVDetectionBBox *), AV_FIFO_FLAG_AUTO_GROW);
+    if (!ctx->bboxes_fifo)
+        return AVERROR(ENOMEM);
     ff_dnn_set_detect_post_proc(&ctx->dnnctx, dnn_detect_post_proc);
 
+    /* parse anchors before the labels: the labels branch below returns early */
+    if (ctx->anchors_str) {
+        ret = dnn_detect_parse_anchors(ctx->anchors_str, &ctx->anchors);
+        if (!ctx->anchors) {
+            av_log(context, AV_LOG_ERROR, "failed to parse anchors_str\n");
+            return AVERROR(EINVAL);
+        }
+        ctx->nb_anchor = ret;
+    }
+
     if (ctx->labels_filename) {
         return read_detect_label_file(context);
     }
 
     return 0;
 }
@@ -460,6 +738,16 @@ static int dnn_detect_activate(AVFilterContext *filter_ctx)
 static av_cold void dnn_detect_uninit(AVFilterContext *context)
 {
     DnnDetectContext *ctx = context->priv;
+    AVDetectionBBox *bbox;
     ff_dnn_uninit(&ctx->dnnctx);
+    /* the fifo is NULL if init returned before allocating it */
+    if (ctx->bboxes_fifo) {
+        while (av_fifo_can_read(ctx->bboxes_fifo)) {
+            av_fifo_read(ctx->bboxes_fifo, &bbox, 1);
+            av_freep(&bbox);
+        }
+        av_fifo_freep2(&ctx->bboxes_fifo);
+    }
+    av_freep(&ctx->anchors);
     free_detect_labels(ctx);
 }