lavfi/dnn_backend_tf: Separate function for filling RequestItem

This commit rearranges the existing code to create separate function
for filling request with execution data.

Signed-off-by: Shubhanshu Saxena <shubhanshu.e01@gmail.com>
This commit is contained in:
Shubhanshu Saxena 2021-07-05 16:00:56 +05:30 committed by Guo Yejun
parent 08d8b3b631
commit b849228ae0
1 changed files with 80 additions and 57 deletions

View File

@ -839,20 +839,16 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, DNNFunctionType func_
return model; return model;
} }
static DNNReturnType execute_model_tf(TFRequestItem *request, Queue *inference_queue) static DNNReturnType fill_model_input_tf(TFModel *tf_model, TFRequestItem *request) {
{ DNNData input;
TFModel *tf_model;
TFContext *ctx;
TFInferRequest *infer_request;
InferenceItem *inference; InferenceItem *inference;
TaskItem *task; TaskItem *task;
DNNData input, *outputs; TFInferRequest *infer_request;
TFContext *ctx = &tf_model->ctx;
inference = ff_queue_pop_front(inference_queue); inference = ff_queue_pop_front(tf_model->inference_queue);
av_assert0(inference); av_assert0(inference);
task = inference->task; task = inference->task;
tf_model = task->model;
ctx = &tf_model->ctx;
request->inference = inference; request->inference = inference;
if (get_input_tf(tf_model, &input, task->input_name) != DNN_SUCCESS) if (get_input_tf(tf_model, &input, task->input_name) != DNN_SUCCESS)
@ -916,6 +912,32 @@ static DNNReturnType execute_model_tf(TFRequestItem *request, Queue *inference_q
infer_request->tf_outputs[i].index = 0; infer_request->tf_outputs[i].index = 0;
} }
return DNN_SUCCESS;
}
static DNNReturnType execute_model_tf(TFRequestItem *request, Queue *inference_queue)
{
TFModel *tf_model;
TFContext *ctx;
TFInferRequest *infer_request;
InferenceItem *inference;
TaskItem *task;
DNNData *outputs;
inference = ff_queue_peek_front(inference_queue);
task = inference->task;
tf_model = task->model;
ctx = &tf_model->ctx;
if (task->async) {
avpriv_report_missing_feature(ctx, "Async execution not supported");
return DNN_ERROR;
} else {
if (fill_model_input_tf(tf_model, request) != DNN_SUCCESS) {
return DNN_ERROR;
}
infer_request = request->infer_request;
TF_SessionRun(tf_model->session, NULL, TF_SessionRun(tf_model->session, NULL,
infer_request->tf_input, &infer_request->input_tensor, 1, infer_request->tf_input, &infer_request->input_tensor, 1,
infer_request->tf_outputs, infer_request->output_tensors, infer_request->tf_outputs, infer_request->output_tensors,
@ -973,6 +995,7 @@ static DNNReturnType execute_model_tf(TFRequestItem *request, Queue *inference_q
av_freep(&outputs); av_freep(&outputs);
ff_safe_queue_push_back(tf_model->request_queue, request); ff_safe_queue_push_back(tf_model->request_queue, request);
return (task->inference_done == task->inference_todo) ? DNN_SUCCESS : DNN_ERROR; return (task->inference_done == task->inference_todo) ? DNN_SUCCESS : DNN_ERROR;
}
} }
DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNExecBaseParams *exec_params) DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNExecBaseParams *exec_params)