osu/osu.Game/Database/MemoryCachingComponent.cs
2023-06-09 19:00:05 +09:00

113 lines
4.0 KiB
C#

// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System;
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
using osu.Framework.Extensions.TypeExtensions;
using osu.Framework.Graphics;
using osu.Framework.Statistics;
namespace osu.Game.Database
{
/// <summary>
/// A component which performs lookups (or calculations) and caches the results.
/// Currently not persisted between game sessions.
/// </summary>
public abstract partial class MemoryCachingComponent<TLookup, TValue> : Component
    where TLookup : notnull
{
    private readonly ConcurrentDictionary<TLookup, TValue?> cache = new ConcurrentDictionary<TLookup, TValue?>();

    private readonly GlobalStatistic<MemoryCachingStatistics> statistics;

    /// <summary>
    /// Whether a <c>null</c> result from <see cref="ComputeValueAsync"/> should be stored in the cache.
    /// When <c>false</c>, a <c>null</c> result is still returned to the caller but is not cached,
    /// so a subsequent lookup for the same key will compute it again.
    /// </summary>
    protected virtual bool CacheNullValues => true;

    protected MemoryCachingComponent()
    {
        // Keyed by this base class' name and the concrete subclass' readable name.
        statistics = GlobalStatistics.Get<MemoryCachingStatistics>(nameof(MemoryCachingComponent<TLookup, TValue>), GetType().ReadableName());
        statistics.Value = new MemoryCachingStatistics();
    }

    /// <summary>
    /// Retrieve the cached value for the given lookup, computing it via <see cref="ComputeValueAsync"/>
    /// (and caching the result) on a cache miss.
    /// </summary>
    /// <remarks>
    /// Concurrent misses for the same lookup are not coalesced: each caller computes independently,
    /// and the last result written to the cache wins.
    /// A result computed after <paramref name="token"/> was cancelled is returned but not cached.
    /// </remarks>
    /// <param name="lookup">The lookup to retrieve.</param>
    /// <param name="token">An optional <see cref="CancellationToken"/> to cancel the operation.</param>
    /// <returns>The cached or newly computed value, which may be <c>null</c>.</returns>
    protected async Task<TValue?> GetAsync(TLookup lookup, CancellationToken token = default)
    {
        if (CheckExists(lookup, out TValue? existing))
        {
            // Interlocked because this method may run concurrently on several threads.
            Interlocked.Increment(ref statistics.Value.HitCount);
            return existing;
        }

        // Record the miss up-front so lookups whose computation throws or is cancelled are still counted.
        Interlocked.Increment(ref statistics.Value.MissCount);

        // ConfigureAwait(false): the remainder of this method may run on an arbitrary thread pool thread.
        var computed = await ComputeValueAsync(lookup, token).ConfigureAwait(false);

        // A value produced after cancellation was requested may be partial or a placeholder;
        // hand it back to the caller but don't let it persist in the cache.
        if (token.IsCancellationRequested)
            return computed;

        if (computed != null || CacheNullValues)
        {
            cache[lookup] = computed;
            statistics.Value.Usage = cache.Count;
        }

        return computed;
    }

    /// <summary>
    /// Invalidate all entries matching a provided predicate.
    /// </summary>
    /// <param name="matchKeyPredicate">The predicate to decide which keys should be invalidated.</param>
    protected void Invalidate(Func<TLookup, bool> matchKeyPredicate)
    {
        // Removing entries while enumerating is safe for ConcurrentDictionary (its enumerator tolerates
        // concurrent modification and does not represent a point-in-time snapshot).
        foreach (var kvp in cache)
        {
            if (matchKeyPredicate(kvp.Key))
                cache.TryRemove(kvp.Key, out _);
        }

        statistics.Value.Usage = cache.Count;
    }

    /// <summary>
    /// Attempts to read a value from the cache without computing it and without updating statistics.
    /// </summary>
    /// <param name="lookup">The lookup to check.</param>
    /// <param name="value">
    /// The cached value when present (which may itself be <c>null</c> if null values are cached); otherwise <c>default</c>.
    /// </param>
    /// <returns>Whether <paramref name="lookup"/> currently has a cache entry.</returns>
    protected bool CheckExists(TLookup lookup, [MaybeNullWhen(false)] out TValue value) =>
        cache.TryGetValue(lookup, out value);

    /// <summary>
    /// Called on cache miss to compute the value for the specified lookup.
    /// </summary>
    /// <remarks>
    /// A <c>null</c> result is only cached when <see cref="CacheNullValues"/> is <c>true</c>.
    /// </remarks>
    /// <param name="lookup">The lookup to retrieve.</param>
    /// <param name="token">An optional <see cref="CancellationToken"/> to cancel the operation.</param>
    /// <returns>The computed value.</returns>
    protected abstract Task<TValue?> ComputeValueAsync(TLookup lookup, CancellationToken token = default);

    /// <summary>
    /// Diagnostic counters for a single caching component, surfaced through global statistics.
    /// Counters are plain fields so they can be updated with <see cref="Interlocked"/>.
    /// </summary>
    private class MemoryCachingStatistics
    {
        /// <summary>
        /// Total number of cache hits.
        /// </summary>
        public int HitCount;

        /// <summary>
        /// Total number of cache misses.
        /// </summary>
        public int MissCount;

        /// <summary>
        /// Total number of cached entities.
        /// </summary>
        public int Usage;

        public override string ToString()
        {
            // Reads are unsynchronised; values may be momentarily inconsistent, which is acceptable for display.
            int totalAccesses = HitCount + MissCount;
            double hitRate = totalAccesses == 0 ? 0 : (double)HitCount / totalAccesses;

            return $"i:{Usage} h:{HitCount} m:{MissCount} {hitRate:0%}";
        }
    }
}
}