From 8c3b0a316737eb359c4d00e03001ab7167ed0633 Mon Sep 17 00:00:00 2001 From: smoogipoo Date: Tue, 26 Jan 2021 22:47:37 +0900 Subject: [PATCH] Fix TaskChain performing the action in-line, add test --- osu.Game.Tests/NonVisual/TaskChainTest.cs | 83 +++++++++++++++++++++++ osu.Game/Utils/TaskChain.cs | 28 ++++++-- 2 files changed, 106 insertions(+), 5 deletions(-) create mode 100644 osu.Game.Tests/NonVisual/TaskChainTest.cs diff --git a/osu.Game.Tests/NonVisual/TaskChainTest.cs b/osu.Game.Tests/NonVisual/TaskChainTest.cs new file mode 100644 index 0000000000..d561fb4c1b --- /dev/null +++ b/osu.Game.Tests/NonVisual/TaskChainTest.cs @@ -0,0 +1,83 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System.Threading; +using System.Threading.Tasks; +using NUnit.Framework; +using osu.Game.Utils; + +namespace osu.Game.Tests.NonVisual +{ + [TestFixture] + public class TaskChainTest + { + private TaskChain taskChain; + private int currentTask; + + [SetUp] + public void Setup() + { + taskChain = new TaskChain(); + currentTask = 0; + } + + [Test] + public async Task TestChainedTasksRunSequentially() + { + var task1 = addTask(); + var task2 = addTask(); + var task3 = addTask(); + + task3.mutex.Set(); + task2.mutex.Set(); + task1.mutex.Set(); + + await Task.WhenAll(task1.task, task2.task, task3.task); + + Assert.That(task1.task.Result, Is.EqualTo(1)); + Assert.That(task2.task.Result, Is.EqualTo(2)); + Assert.That(task3.task.Result, Is.EqualTo(3)); + } + + [Test] + public async Task TestChainedTaskWithIntermediateCancelRunsInSequence() + { + var task1 = addTask(); + var task2 = addTask(); + var task3 = addTask(); + + // Cancel task2, allow task3 to complete. + task2.cancellation.Cancel(); + task2.mutex.Set(); + task3.mutex.Set(); + + // Allow task3 to potentially complete. + Thread.Sleep(1000); + + // Allow task1 to complete. + task1.mutex.Set(); + + // Wait on both tasks. + await Task.WhenAll(task1.task, task3.task); + + Assert.That(task1.task.Result, Is.EqualTo(1)); + Assert.That(task2.task.IsCompleted, Is.False); + Assert.That(task3.task.Result, Is.EqualTo(2)); + } + + private (Task task, ManualResetEventSlim mutex, CancellationTokenSource cancellation) addTask() + { + var mutex = new ManualResetEventSlim(false); + var cancellationSource = new CancellationTokenSource(); + var completionSource = new TaskCompletionSource(); + + taskChain.Add(() => + { + mutex.Wait(CancellationToken.None); + completionSource.SetResult(Interlocked.Increment(ref currentTask)); + }, cancellationSource.Token); + + return (completionSource.Task, mutex, cancellationSource); + } + } +} diff --git a/osu.Game/Utils/TaskChain.cs b/osu.Game/Utils/TaskChain.cs index 64d523bd3d..2bc2c00e28 100644 --- a/osu.Game/Utils/TaskChain.cs +++ b/osu.Game/Utils/TaskChain.cs @@ -4,6 +4,7 @@ #nullable enable using System; +using System.Threading; using System.Threading.Tasks; namespace osu.Game.Utils @@ -19,15 +20,32 @@ public class TaskChain /// /// Adds a new task to the end of this . /// - /// The task creation function. + /// The action to be executed. + /// The for this task. Does not affect further tasks in the chain. /// The awaitable . - public Task Add(Func taskFunc) + public Task Add(Action action, CancellationToken cancellationToken = default) { lock (currentTaskLock) { - currentTask = currentTask == null - ? taskFunc() - : currentTask.ContinueWith(_ => taskFunc()).Unwrap(); + // Note: Attaching the cancellation token to the continuation could lead to re-ordering of tasks in the chain. + // Therefore, the cancellation token is not used to cancel the continuation but only the run of each task. + if (currentTask == null) + { + currentTask = Task.Run(() => + { + cancellationToken.ThrowIfCancellationRequested(); + action(); + }, CancellationToken.None); + } + else + { + currentTask = currentTask.ContinueWith(_ => + { + cancellationToken.ThrowIfCancellationRequested(); + action(); + }, CancellationToken.None); + } + return currentTask; } }