diff --git a/demux/demux.c b/demux/demux.c
index 2b58ca00cf..5aca3c7aa6 100644
--- a/demux/demux.c
+++ b/demux/demux.c
@@ -349,6 +349,7 @@ struct demux_stream {
     int64_t back_restart_pos;
     double back_restart_dts;
     bool back_restart_eof; // restart position is at EOF; overrides pos/dts
+    bool back_restart_next; // restart before next keyframe; overrides above
     bool back_restarting;   // searching keyframe before restart pos
     // Current PTS lower bound for back demuxing.
     double back_seek_pos;
@@ -756,6 +757,7 @@ static void ds_clear_reader_state(struct demux_stream *ds,
         ds->back_restart_pos = -1;
         ds->back_restart_dts = MP_NOPTS_VALUE;
         ds->back_restart_eof = false;
+        ds->back_restart_next = false;
         ds->back_restarting = false;
         ds->back_seek_pos = MP_NOPTS_VALUE;
         ds->back_resume_pos = -1;
@@ -1286,6 +1288,30 @@ static void find_backward_restart_pos(struct demux_stream *ds)
     // If this is NULL, look for EOF (resume from very last keyframe).
     struct demux_packet *back_restart = NULL;
 
+    if (ds->back_restart_next) {
+        // Initial state. Switch to one of the other modi.
+
+        for (struct demux_packet *cur = first; cur; cur = cur->next) {
+            // Restart for next keyframe after reader_head.
+            if (cur != first && cur->keyframe) {
+                ds->back_restart_dts = cur->dts;
+                ds->back_restart_pos = cur->pos;
+                ds->back_restart_eof = false;
+                ds->back_restart_next = false;
+                break;
+            }
+        }
+
+        if (ds->back_restart_next && ds->eof) {
+            // Restart from end if nothing was found.
+            ds->back_restart_eof = true;
+            ds->back_restart_next = false;
+        }
+
+        if (ds->back_restart_next)
+            return;
+    }
+
     if (ds->back_restart_eof) {
         // We're trying to find EOF (without discarding packets). Only continue
         // if we really reach EOF.
@@ -1423,7 +1449,7 @@ resume_earlier:
         if (ds2 == ds || !ds2->eager)
             continue;
 
-        if (!ds2->reader_head && !ds2->back_resuming && !ds2->back_restarting) {
+        if (ds2->back_restarting && ds2->back_restart_next) {
             MP_VERBOSE(in, "delaying stream %d for %d\n", ds->index, ds2->index);
             return;
         }
@@ -1476,6 +1502,8 @@ static void step_backwards(struct demux_stream *ds)
     assert(!ds->back_restarting);
     ds->back_restarting = true;
 
+    ds->back_restart_next = false;
+
     // No valid restart pos, but EOF reached -> find last restart pos before EOF.
     ds->back_restart_eof = ds->back_restart_dts == MP_NOPTS_VALUE &&
                            ds->back_restart_pos < 0 &&
@@ -2237,6 +2265,12 @@ static int dequeue_packet(struct demux_stream *ds, struct demux_packet **res)
     if (in->blocked)
         return 0;
 
+    if (ds->eager) {
+        in->reading = true; // enable readahead
+        in->eof = false; // force retry
+        pthread_cond_signal(&in->wakeup); // possibly read more
+    }
+
     if (ds->back_resuming || ds->back_restarting) {
         assert(in->back_demuxing);
         return 0;
@@ -2255,12 +2289,6 @@ static int dequeue_packet(struct demux_stream *ds, struct demux_packet **res)
         return 1;
     }
 
-    if (ds->eager) {
-        in->reading = true; // enable readahead
-        in->eof = false; // force retry
-        pthread_cond_signal(&in->wakeup); // possibly read more
-    }
-
     bool eof = !ds->reader_head && ds->eof;
 
     if (in->back_demuxing) {
@@ -3270,12 +3298,22 @@ static bool queue_seek(struct demux_internal *in, double seek_pts, int flags,
     for (int n = 0; n < in->num_streams; n++) {
         struct demux_stream *ds = in->streams[n]->ds;
 
-        if (in->back_demuxing && clear_back_state)
+        if (in->back_demuxing && clear_back_state) {
             ds->back_seek_pos = seek_pts;
+            ds->back_restarting = ds->eager;
+            ds->back_restart_next = true;
+        }
 
         wakeup_ds(ds);
     }
 
+    if (in->back_demuxing) {
+        // Process possibly cached packets. Separate from the loop above, since
+        // all flags must be set on all streams before this function is called.
+        for (int n = 0; n < in->num_streams; n++)
+            back_demux_see_packets(in->streams[n]->ds);
+    }
+
     if (!in->threading && in->seeking)
         execute_seek(in);