diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c index 74fe06d6fb..552a9f2fa1 100644 --- a/libavfilter/dnn/dnn_backend_tf.c +++ b/libavfilter/dnn/dnn_backend_tf.c @@ -274,6 +274,7 @@ static int get_input_tf(void *model, DNNData *input, const char *input_name) TFModel *tf_model = model; TFContext *ctx = &tf_model->ctx; TF_Status *status; + TF_DataType dt; int64_t dims[4]; TF_Output tf_output; @@ -284,7 +285,18 @@ static int get_input_tf(void *model, DNNData *input, const char *input_name) } tf_output.index = 0; - input->dt = TF_OperationOutputType(tf_output); + dt = TF_OperationOutputType(tf_output); + switch (dt) { + case TF_FLOAT: + input->dt = DNN_FLOAT; + break; + case TF_UINT8: + input->dt = DNN_UINT8; + break; + default: + av_log(ctx, AV_LOG_ERROR, "Unsupported output type %d in model\n", dt); + return AVERROR(EINVAL); + } input->order = DCO_RGB; status = TF_NewStatus();