1
0
mirror of https://git.ffmpeg.org/ffmpeg.git synced 2025-03-25 04:19:05 +00:00

libavfilter/dnn: add more data type support for dnn model input

Currently, only float is supported as model input. Models can also take
other input data types, so this patch adds support for uint8.

Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
Signed-off-by: Pedro Arthur <bygrandao@gmail.com>
This commit is contained in:
Guo, Yejun 2019-04-25 10:14:42 +08:00 committed by Pedro Arthur
parent 25c1cd909f
commit c636dc9819
4 changed files with 39 additions and 7 deletions

View File

@ -24,8 +24,9 @@
*/ */
#include "dnn_backend_native.h" #include "dnn_backend_native.h"
#include "libavutil/avassert.h"
static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output) static DNNReturnType set_input_output_native(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output)
{ {
ConvolutionalNetwork *network = (ConvolutionalNetwork *)model; ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
InputParams *input_params; InputParams *input_params;
@ -45,6 +46,7 @@ static DNNReturnType set_input_output_native(void *model, DNNData *input, const
if (input->data){ if (input->data){
av_freep(&input->data); av_freep(&input->data);
} }
av_assert0(input->dt == DNN_FLOAT);
network->layers[0].output = input->data = av_malloc(cur_height * cur_width * cur_channels * sizeof(float)); network->layers[0].output = input->data = av_malloc(cur_height * cur_width * cur_channels * sizeof(float));
if (!network->layers[0].output){ if (!network->layers[0].output){
return DNN_ERROR; return DNN_ERROR;

View File

@ -79,10 +79,31 @@ static TF_Buffer *read_graph(const char *model_filename)
return graph_buf; return graph_buf;
} }
static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output) static TF_Tensor *allocate_input_tensor(const DNNInputData *input)
{
// allocate_input_tensor: allocates a TensorFlow tensor sized for the given input description.
// The tensor has four dimensions, {1, height, width, channels}, and its element
// type follows input->dt (DNN_FLOAT -> TF_FLOAT, DNN_UINT8 -> TF_UINT8).
// Returns the result of TF_AllocateTensor unchanged; set_input_output_tf checks it for NULL.
// TensorFlow element type selected from input->dt.
TF_DataType dt;
// Size in bytes of one element of the selected type.
size_t size;
// The leading dimension is fixed at 1; the rest come from the input description.
int64_t input_dims[] = {1, input->height, input->width, input->channels};
switch (input->dt) {
case DNN_FLOAT:
dt = TF_FLOAT;
size = sizeof(float);
break;
case DNN_UINT8:
dt = TF_UINT8;
size = sizeof(char);
break;
default:
// Any other DNNDataType value is treated as a programming error.
av_assert0(!"should not reach here");
}
// Buffer length in bytes: height * width * channels * element size.
return TF_AllocateTensor(dt, input_dims, 4,
input_dims[1] * input_dims[2] * input_dims[3] * size);
}
static DNNReturnType set_input_output_tf(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output)
{ {
TFModel *tf_model = (TFModel *)model; TFModel *tf_model = (TFModel *)model;
int64_t input_dims[] = {1, input->height, input->width, input->channels};
TF_SessionOptions *sess_opts; TF_SessionOptions *sess_opts;
const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init"); const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");
@ -95,8 +116,7 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char
if (tf_model->input_tensor){ if (tf_model->input_tensor){
TF_DeleteTensor(tf_model->input_tensor); TF_DeleteTensor(tf_model->input_tensor);
} }
tf_model->input_tensor = TF_AllocateTensor(TF_FLOAT, input_dims, 4, tf_model->input_tensor = allocate_input_tensor(input);
input_dims[1] * input_dims[2] * input_dims[3] * sizeof(float));
if (!tf_model->input_tensor){ if (!tf_model->input_tensor){
return DNN_ERROR; return DNN_ERROR;
} }

View File

@ -32,6 +32,14 @@ typedef enum {DNN_SUCCESS, DNN_ERROR} DNNReturnType;
typedef enum {DNN_NATIVE, DNN_TF} DNNBackendType; typedef enum {DNN_NATIVE, DNN_TF} DNNBackendType;
// Element types that model input data may use; stored in DNNInputData.dt.
typedef enum {DNN_FLOAT, DNN_UINT8} DNNDataType;
// Describes the data passed to a model as input through set_input_output.
typedef struct DNNInputData{
// Untyped pointer to the input buffer; presumably its elements are of the type given by dt.
void *data;
// Element type of the input (DNN_FLOAT or DNN_UINT8).
DNNDataType dt;
// Input dimensions.
int width, height, channels;
} DNNInputData;
typedef struct DNNData{ typedef struct DNNData{
float *data; float *data;
int width, height, channels; int width, height, channels;
@ -42,7 +50,7 @@ typedef struct DNNModel{
void *model; void *model;
// Sets model input and output. // Sets model input and output.
// Should be called at least once before model execution. // Should be called at least once before model execution.
DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output); DNNReturnType (*set_input_output)(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output);
} DNNModel; } DNNModel;
// Stores pointers to functions for loading, executing, freeing DNN models for one of the backends. // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.

View File

@ -40,7 +40,8 @@ typedef struct SRContext {
DNNBackendType backend_type; DNNBackendType backend_type;
DNNModule *dnn_module; DNNModule *dnn_module;
DNNModel *model; DNNModel *model;
DNNData input, output; DNNInputData input;
DNNData output;
int scale_factor; int scale_factor;
struct SwsContext *sws_contexts[3]; struct SwsContext *sws_contexts[3];
int sws_slice_h, sws_input_linesize, sws_output_linesize; int sws_slice_h, sws_input_linesize, sws_output_linesize;
@ -86,6 +87,7 @@ static av_cold int init(AVFilterContext *context)
return AVERROR(EIO); return AVERROR(EIO);
} }
sr_context->input.dt = DNN_FLOAT;
sr_context->sws_contexts[0] = NULL; sr_context->sws_contexts[0] = NULL;
sr_context->sws_contexts[1] = NULL; sr_context->sws_contexts[1] = NULL;
sr_context->sws_contexts[2] = NULL; sr_context->sws_contexts[2] = NULL;