diff --git a/libavfilter/opencl.c b/libavfilter/opencl.c index 005ad089e2..37afc41f8b 100644 --- a/libavfilter/opencl.c +++ b/libavfilter/opencl.c @@ -42,11 +42,29 @@ int ff_opencl_filter_query_formats(AVFilterContext *avctx) return ff_set_common_formats(avctx, formats); } +static int opencl_filter_set_device(AVFilterContext *avctx, + AVBufferRef *device) +{ + OpenCLFilterContext *ctx = avctx->priv; + + av_buffer_unref(&ctx->device_ref); + + ctx->device_ref = av_buffer_ref(device); + if (!ctx->device_ref) + return AVERROR(ENOMEM); + + ctx->device = (AVHWDeviceContext*)ctx->device_ref->data; + ctx->hwctx = ctx->device->hwctx; + + return 0; +} + int ff_opencl_filter_config_input(AVFilterLink *inlink) { AVFilterContext *avctx = inlink->dst; OpenCLFilterContext *ctx = avctx->priv; AVHWFramesContext *input_frames; + int err; if (!inlink->hw_frames_ctx) { av_log(avctx, AV_LOG_ERROR, "OpenCL filtering requires a " @@ -59,15 +77,12 @@ int ff_opencl_filter_config_input(AVFilterLink *inlink) return 0; input_frames = (AVHWFramesContext*)inlink->hw_frames_ctx->data; - if (input_frames->format != AV_PIX_FMT_OPENCL) return AVERROR(EINVAL); - ctx->device_ref = av_buffer_ref(input_frames->device_ref); - if (!ctx->device_ref) - return AVERROR(ENOMEM); - ctx->device = input_frames->device_ctx; - ctx->hwctx = ctx->device->hwctx; + err = opencl_filter_set_device(avctx, input_frames->device_ref); + if (err < 0) + return err; // Default output parameters match input parameters. if (ctx->output_format == AV_PIX_FMT_NONE) @@ -90,6 +105,18 @@ int ff_opencl_filter_config_output(AVFilterLink *outlink) av_buffer_unref(&outlink->hw_frames_ctx); + if (!ctx->device_ref) { + if (!avctx->hw_device_ctx) { + av_log(avctx, AV_LOG_ERROR, "OpenCL filtering requires an " + "OpenCL device.\n"); + return AVERROR(EINVAL); + } + + err = opencl_filter_set_device(avctx, avctx->hw_device_ctx); + if (err < 0) + return err; + } + output_frames_ref = av_hwframe_ctx_alloc(ctx->device_ref); if (!output_frames_ref) { err = AVERROR(ENOMEM);