Fix TaskChain performing the action in-line, add test

This commit is contained in:
smoogipoo 2021-01-26 22:47:37 +09:00
parent c17774e23c
commit 8c3b0a3167
2 changed files with 106 additions and 5 deletions

View File

@ -0,0 +1,83 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System.Threading;
using System.Threading.Tasks;
using NUnit.Framework;
using osu.Game.Utils;
namespace osu.Game.Tests.NonVisual
{
[TestFixture]
public class TaskChainTest
{
private TaskChain taskChain;
private int currentTask;
[SetUp]
public void Setup()
{
taskChain = new TaskChain();
currentTask = 0;
}
[Test]
public async Task TestChainedTasksRunSequentially()
{
var task1 = addTask();
var task2 = addTask();
var task3 = addTask();
task3.mutex.Set();
task2.mutex.Set();
task1.mutex.Set();
await Task.WhenAll(task1.task, task2.task, task3.task);
Assert.That(task1.task.Result, Is.EqualTo(1));
Assert.That(task2.task.Result, Is.EqualTo(2));
Assert.That(task3.task.Result, Is.EqualTo(3));
}
[Test]
public async Task TestChainedTaskWithIntermediateCancelRunsInSequence()
{
var task1 = addTask();
var task2 = addTask();
var task3 = addTask();
// Cancel task2, allow task3 to complete.
task2.cancellation.Cancel();
task2.mutex.Set();
task3.mutex.Set();
// Allow task3 to potentially complete.
Thread.Sleep(1000);
// Allow task1 to complete.
task1.mutex.Set();
// Wait on both tasks.
await Task.WhenAll(task1.task, task3.task);
Assert.That(task1.task.Result, Is.EqualTo(1));
Assert.That(task2.task.IsCompleted, Is.False);
Assert.That(task3.task.Result, Is.EqualTo(2));
}
private (Task<int> task, ManualResetEventSlim mutex, CancellationTokenSource cancellation) addTask()
{
var mutex = new ManualResetEventSlim(false);
var cancellationSource = new CancellationTokenSource();
var completionSource = new TaskCompletionSource<int>();
taskChain.Add(() =>
{
mutex.Wait(CancellationToken.None);
completionSource.SetResult(Interlocked.Increment(ref currentTask));
}, cancellationSource.Token);
return (completionSource.Task, mutex, cancellationSource);
}
}
}

View File

@ -4,6 +4,7 @@
#nullable enable
using System;
using System.Threading;
using System.Threading.Tasks;
namespace osu.Game.Utils
@ -19,15 +20,32 @@ public class TaskChain
/// <summary>
/// Adds a new task to the end of this <see cref="TaskChain"/>.
/// </summary>
/// <param name="taskFunc">The task creation function.</param>
/// <param name="action">The action to be executed.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> for this task. Does not affect further tasks in the chain.</param>
/// <returns>The awaitable <see cref="Task"/>.</returns>
public Task Add(Func<Task> taskFunc)
public Task Add(Action action, CancellationToken cancellationToken = default)
{
lock (currentTaskLock)
{
currentTask = currentTask == null
? taskFunc()
: currentTask.ContinueWith(_ => taskFunc()).Unwrap();
// Note: Attaching the cancellation token to the continuation could lead to re-ordering of tasks in the chain.
// Therefore, the cancellation token is not used to cancel the continuation but only the run of each task.
if (currentTask == null)
{
currentTask = Task.Run(() =>
{
cancellationToken.ThrowIfCancellationRequested();
action();
}, CancellationToken.None);
}
else
{
currentTask = currentTask.ContinueWith(_ =>
{
cancellationToken.ThrowIfCancellationRequested();
action();
}, CancellationToken.None);
}
return currentTask;
}
}