diff --git a/libavcodec/cbs_av1.c b/libavcodec/cbs_av1.c index 154d9156cf..45e1288a51 100644 --- a/libavcodec/cbs_av1.c +++ b/libavcodec/cbs_av1.c @@ -1058,15 +1058,31 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx, AV1RawTileData *td; size_t header_size; int err, start_pos, end_pos, data_pos; + CodedBitstreamAV1Context av1ctx; // OBUs in the normal bitstream format must contain a size field // in every OBU (in annex B it is optional, but we don't support // writing that). obu->header.obu_has_size_field = 1; + av1ctx = *priv; + + if (priv->sequence_header_ref) { + av1ctx.sequence_header_ref = av_buffer_ref(priv->sequence_header_ref); + if (!av1ctx.sequence_header_ref) + return AVERROR(ENOMEM); + } + + if (priv->frame_header_ref) { + av1ctx.frame_header_ref = av_buffer_ref(priv->frame_header_ref); + if (!av1ctx.frame_header_ref) { + err = AVERROR(ENOMEM); + goto error; + } + } err = cbs_av1_write_obu_header(ctx, pbc, &obu->header); if (err < 0) - return err; + goto error; if (obu->header.obu_has_size_field) { pbc_tmp = *pbc; @@ -1084,18 +1100,21 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx, err = cbs_av1_write_sequence_header_obu(ctx, pbc, &obu->obu.sequence_header); if (err < 0) - return err; + goto error; av_buffer_unref(&priv->sequence_header_ref); priv->sequence_header = NULL; err = ff_cbs_make_unit_refcounted(ctx, unit); if (err < 0) - return err; + goto error; priv->sequence_header_ref = av_buffer_ref(unit->content_ref); - if (!priv->sequence_header_ref) - return AVERROR(ENOMEM); + if (!priv->sequence_header_ref) { + err = AVERROR(ENOMEM); + goto error; + } + priv->sequence_header = &obu->obu.sequence_header; } break; @@ -1103,7 +1122,7 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx, { err = cbs_av1_write_temporal_delimiter_obu(ctx, pbc); if (err < 0) - return err; + goto error; } break; case AV1_OBU_FRAME_HEADER: @@ -1115,7 +1134,7 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx, AV1_OBU_REDUNDANT_FRAME_HEADER, NULL); if (err < 0) - return err; + goto error; } break; case AV1_OBU_TILE_GROUP: @@ -1123,7 +1142,7 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx, err = cbs_av1_write_tile_group_obu(ctx, pbc, &obu->obu.tile_group); if (err < 0) - return err; + goto error; td = &obu->obu.tile_group.tile_data; } @@ -1132,7 +1151,7 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx, { err = cbs_av1_write_frame_obu(ctx, pbc, &obu->obu.frame, NULL); if (err < 0) - return err; + goto error; td = &obu->obu.frame.tile_group.tile_data; } @@ -1141,7 +1160,7 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx, { err = cbs_av1_write_tile_list_obu(ctx, pbc, &obu->obu.tile_list); if (err < 0) - return err; + goto error; td = &obu->obu.tile_list.tile_data; } @@ -1150,18 +1169,19 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx, { err = cbs_av1_write_metadata_obu(ctx, pbc, &obu->obu.metadata); if (err < 0) - return err; + goto error; } break; case AV1_OBU_PADDING: { err = cbs_av1_write_padding_obu(ctx, pbc, &obu->obu.padding); if (err < 0) - return err; + goto error; } break; default: - return AVERROR(ENOSYS); + err = AVERROR(ENOSYS); + goto error; } end_pos = put_bits_count(pbc); @@ -1172,7 +1192,7 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx, // Add trailing bits and recalculate. err = cbs_av1_write_trailing_bits(ctx, pbc, 8 - end_pos % 8); if (err < 0) - return err; + goto error; end_pos = put_bits_count(pbc); obu->obu_size = header_size = (end_pos - start_pos + 7) / 8; } else { @@ -1190,14 +1210,19 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx, *pbc = pbc_tmp; err = cbs_av1_write_leb128(ctx, pbc, "obu_size", obu->obu_size); if (err < 0) - return err; + goto error; data_pos = put_bits_count(pbc) / 8; flush_put_bits(pbc); av_assert0(data_pos <= start_pos); - if (8 * obu->obu_size > put_bits_left(pbc)) + if (8 * obu->obu_size > put_bits_left(pbc)) { + av_buffer_unref(&priv->sequence_header_ref); + av_buffer_unref(&priv->frame_header_ref); + *priv = av1ctx; + return AVERROR(ENOSPC); + } if (obu->obu_size > 0) { memmove(pbc->buf + data_pos, @@ -1213,8 +1238,13 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx, // OBU data must be byte-aligned. av_assert0(put_bits_count(pbc) % 8 == 0); + err = 0; - return 0; +error: + av_buffer_unref(&av1ctx.sequence_header_ref); + av_buffer_unref(&av1ctx.frame_header_ref); + + return err; } static int cbs_av1_assemble_fragment(CodedBitstreamContext *ctx,