osu/osu.Game/Database/OnlineLookupCache.cs

174 lines
6.6 KiB
C#

// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
#nullable disable
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using JetBrains.Annotations;
using osu.Framework.Allocation;
using osu.Framework.Extensions;
using osu.Game.Online.API;
namespace osu.Game.Database
{
/// <summary>
/// A <see cref="MemoryCachingComponent{TLookup,TValue}"/> which resolves cache misses via online API requests.
/// Individual lookups are queued and resolved in batches of unique IDs, with at most one request task active at a time.
/// </summary>
/// <typeparam name="TLookup">The ID type used to look values up.</typeparam>
/// <typeparam name="TValue">The type of value being retrieved, exposing its own online ID.</typeparam>
/// <typeparam name="TRequest">The type of <see cref="APIRequest"/> used to retrieve a batch of values.</typeparam>
public abstract partial class OnlineLookupCache<TLookup, TValue, TRequest> : MemoryCachingComponent<TLookup, TValue>
    where TLookup : IEquatable<TLookup>
    where TValue : class, IHasOnlineID<TLookup>
    where TRequest : APIRequest
{
    /// <summary>
    /// The maximum number of unique IDs included in a single request.
    /// </summary>
    private const int max_batch_size = 50;

    [Resolved]
    private IAPIProvider api { get; set; }

    /// <summary>
    /// Creates an <see cref="APIRequest"/> to retrieve the values for a given collection of <typeparamref name="TLookup"/>s.
    /// </summary>
    /// <param name="ids">The IDs to perform the lookup with.</param>
    protected abstract TRequest CreateRequest(IEnumerable<TLookup> ids);

    /// <summary>
    /// Retrieves a list of <typeparamref name="TValue"/>s from a successful <typeparamref name="TRequest"/> created by <see cref="CreateRequest"/>.
    /// </summary>
    /// <param name="request">The request which has been performed.</param>
    /// <returns>The retrieved values, or null if none are available.</returns>
    [CanBeNull]
    protected abstract IEnumerable<TValue> RetrieveResults(TRequest request);

    /// <summary>
    /// Perform a lookup using the specified <paramref name="id"/>, populating a <typeparamref name="TValue"/>.
    /// </summary>
    /// <param name="id">The ID to lookup.</param>
    /// <param name="token">An optional cancellation token.</param>
    /// <returns>The populated <typeparamref name="TValue"/>, or null if the value does not exist or the request could not be satisfied.</returns>
    [ItemCanBeNull]
    protected Task<TValue> LookupAsync(TLookup id, CancellationToken token = default) => GetAsync(id, token);

    /// <summary>
    /// Perform an API lookup on the specified <paramref name="ids"/>, populating a <typeparamref name="TValue"/>.
    /// </summary>
    /// <param name="ids">The IDs to lookup.</param>
    /// <param name="token">An optional cancellation token.</param>
    /// <returns>The populated values. May include null results for failed retrievals.</returns>
    protected Task<TValue[]> LookupAsync(TLookup[] ids, CancellationToken token = default)
    {
        var lookupTasks = new List<Task<TValue>>(ids.Length);

        foreach (var id in ids)
        {
            lookupTasks.Add(LookupAsync(id, token).ContinueWith(task =>
            {
                // A faulted or cancelled individual lookup is reported as a null entry rather than failing the whole set.
                if (!task.IsCompletedSuccessfully)
                    return null;

                return task.GetResultSafely();
            }, token));
        }

        return Task.WhenAll(lookupTasks);
    }

    // cannot be sealed due to test usages (see TestUserLookupCache).
    // note that the cancellation token is not forwarded, as the lookup is queued into a batch shared with other callers.
    protected override async Task<TValue> ComputeValueAsync(TLookup lookup, CancellationToken token = default)
        => await queryValue(lookup).ConfigureAwait(false);

    /// <summary>
    /// Lookups which are waiting to be included in a batch, in the order they were requested.
    /// </summary>
    private readonly Queue<(TLookup id, TaskCompletionSource<TValue>)> pendingTasks = new Queue<(TLookup, TaskCompletionSource<TValue>)>();

    /// <summary>
    /// The task currently responsible for draining <see cref="pendingTasks"/>, or null if no such task is active.
    /// </summary>
    private Task pendingRequestTask;

    /// <summary>
    /// Guards access to <see cref="pendingTasks"/> and <see cref="pendingRequestTask"/>.
    /// </summary>
    private readonly object taskAssignmentLock = new object();

    /// <summary>
    /// Queues a lookup for the specified <paramref name="id"/>, starting a request task if one is not already active.
    /// </summary>
    /// <param name="id">The ID to lookup.</param>
    /// <returns>A task which completes once the batch containing this lookup has been processed.</returns>
    private Task<TValue> queryValue(TLookup id)
    {
        lock (taskAssignmentLock)
        {
            // Continuations are forced to run asynchronously so that awaiting code never executes inline
            // on the lookup task, and in particular never while taskAssignmentLock is held.
            var tcs = new TaskCompletionSource<TValue>(TaskCreationOptions.RunContinuationsAsynchronously);

            // Add to the queue.
            pendingTasks.Enqueue((id, tcs));

            // Create a request task if there's not already one.
            if (pendingRequestTask == null)
                createNewTask();

            return tcs.Task;
        }
    }

    /// <summary>
    /// Dequeues up to <see cref="max_batch_size"/> unique IDs, performs a single request for them and completes the waiting lookups.
    /// </summary>
    /// <remarks>
    /// Every dequeued lookup is completed with a value, with null, or with the exception which interrupted the batch.
    /// <see cref="finishPendingTask"/> runs exactly once per invocation, so a failure can never stall the queue.
    /// </remarks>
    private async Task performLookup()
    {
        // contains at most max_batch_size unique IDs, each mapped to every caller waiting on that ID.
        var nextTaskBatch = new Dictionary<TLookup, List<TaskCompletionSource<TValue>>>();

        // finishPendingTask() must run exactly once: running it twice could allow two request tasks to be active concurrently.
        bool requestSlotReleased = false;

        try
        {
            // Grab at most max_batch_size unique IDs from the queue.
            lock (taskAssignmentLock)
            {
                while (pendingTasks.Count > 0 && nextTaskBatch.Count < max_batch_size)
                {
                    (TLookup id, TaskCompletionSource<TValue> task) next = pendingTasks.Dequeue();

                    // Perform a secondary check for existence, in case the value was queried in a previous batch.
                    if (CheckExists(next.id, out var existing))
                        next.task.SetResult(existing);
                    else
                    {
                        if (nextTaskBatch.TryGetValue(next.id, out var tasks))
                            tasks.Add(next.task);
                        else
                            nextTaskBatch[next.id] = new List<TaskCompletionSource<TValue>> { next.task };
                    }
                }
            }

            // Everything dequeued was already satisfied; the finally block releases the request slot.
            if (nextTaskBatch.Count == 0)
                return;

            // Query the values.
            var request = CreateRequest(nextTaskBatch.Keys.ToArray());

            // rather than queueing, we maintain our own single-threaded request stream.
            // todo: we probably want retry logic here.
            await api.PerformAsync(request).ConfigureAwait(false);

            // The request is done, so the next batch may begin while the results of this one are distributed.
            requestSlotReleased = true;
            finishPendingTask();

            var foundValues = RetrieveResults(request);

            if (foundValues != null)
            {
                foreach (var value in foundValues)
                {
                    if (nextTaskBatch.TryGetValue(value.OnlineID, out var tasks))
                    {
                        foreach (var task in tasks)
                            task.SetResult(value);

                        nextTaskBatch.Remove(value.OnlineID);
                    }
                }
            }

            // if any tasks remain which were not satisfied, return null.
            foreach (var tasks in nextTaskBatch.Values)
            {
                foreach (var task in tasks)
                    task.SetResult(null);
            }
        }
        catch (Exception e) when (nextTaskBatch.Count > 0)
        {
            // Hand the failure to every lookup which has not been satisfied, rather than leaving callers waiting forever.
            // The Try* variant is used as part of the batch may have already been completed before the failure.
            foreach (var tasks in nextTaskBatch.Values)
            {
                foreach (var task in tasks)
                    task.TrySetException(e);
            }
        }
        finally
        {
            // Covers the empty-batch early return and any failure prior to the request completing.
            if (!requestSlotReleased)
                finishPendingTask();
        }
    }

    /// <summary>
    /// Marks the active request task as finished, starting a new one if lookups are still queued.
    /// </summary>
    private void finishPendingTask()
    {
        // Create a new request task if there's still more values to query.
        lock (taskAssignmentLock)
        {
            pendingRequestTask = null;

            if (pendingTasks.Count > 0)
                createNewTask();
        }
    }

    /// <summary>
    /// Starts a new run of <see cref="performLookup"/> and tracks it as the active request task.
    /// </summary>
    private void createNewTask() => pendingRequestTask = Task.Run(performLookup);
}
}