Only write when writes occur

Also add finaliser logic for safety. Also better threading. Also more cleanup.
This commit is contained in:
Dean Herbert 2018-02-12 19:57:21 +09:00
parent edc3638175
commit 8b37fde15b
8 changed files with 86 additions and 50 deletions

View File

@ -172,7 +172,7 @@ public List<BeatmapSetInfo> Import(params string[] paths)
/// <param name="archive">The beatmap to be imported.</param>
public BeatmapSetInfo Import(ArchiveReader archive)
{
using ( contextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes.
using (contextFactory.GetForWrite()) // used to share a context for full import. keep in mind this will block all writes.
{
// create a new set info (don't yet add to database)
var beatmapSet = createBeatmapSetInfo(archive);
@ -181,7 +181,7 @@ public BeatmapSetInfo Import(ArchiveReader archive)
var existingHashMatch = beatmaps.BeatmapSets.FirstOrDefault(b => b.Hash == beatmapSet.Hash);
if (existingHashMatch != null)
{
undelete(existingHashMatch);
Undelete(existingHashMatch);
return existingHashMatch;
}
@ -315,9 +315,9 @@ public void Download(BeatmapSetInfo beatmapSetInfo, bool noVideo = false)
/// <param name="beatmapSet">The beatmap set to delete.</param>
public void Delete(BeatmapSetInfo beatmapSet)
{
using (var db = contextFactory.GetForWrite())
using (var usage = contextFactory.GetForWrite())
{
var context = db.Context;
var context = usage.Context;
context.ChangeTracker.AutoDetectChangesEnabled = false;
@ -378,11 +378,16 @@ public void Undelete(BeatmapSetInfo beatmapSet)
if (beatmapSet.Protected)
return;
using (var db = contextFactory.GetForWrite())
using (var usage = contextFactory.GetForWrite())
{
db.Context.ChangeTracker.AutoDetectChangesEnabled = false;
undelete(beatmapSet);
db.Context.ChangeTracker.AutoDetectChangesEnabled = true;
usage.Context.ChangeTracker.AutoDetectChangesEnabled = false;
if (!beatmaps.Undelete(beatmapSet)) return;
if (!beatmapSet.Protected)
files.Reference(beatmapSet.Files.Select(f => f.FileInfo).ToArray());
usage.Context.ChangeTracker.AutoDetectChangesEnabled = true;
}
}
@ -398,21 +403,6 @@ public void Undelete(BeatmapSetInfo beatmapSet)
/// <param name="beatmap">The beatmap difficulty to restore.</param>
public void Restore(BeatmapInfo beatmap) => beatmaps.Restore(beatmap);
/// <summary>
/// Returns a <see cref="BeatmapSetInfo"/> to a usable state if it has previously been deleted but not yet purged.
/// Is a no-op for already usable beatmaps.
/// </summary>
/// <param name="beatmaps">The store to restore beatmaps from.</param>
/// <param name="files">The store to restore beatmap files from.</param>
/// <param name="beatmapSet">The beatmap to restore.</param>
private void undelete(BeatmapSetInfo beatmapSet)
{
if (!beatmaps.Undelete(beatmapSet)) return;
if (!beatmapSet.Protected)
files.Reference(beatmapSet.Files.Select(f => f.FileInfo).ToArray());
}
/// <summary>
/// Retrieve a <see cref="WorkingBeatmap"/> instance for the provided <see cref="BeatmapInfo"/>
/// </summary>

View File

@ -31,9 +31,9 @@ public BeatmapStore(DatabaseContextFactory factory)
/// <param name="beatmapSet">The beatmap to add.</param>
public void Add(BeatmapSetInfo beatmapSet)
{
using (var db = ContextFactory.GetForWrite())
using (var usage = ContextFactory.GetForWrite())
{
var context = db.Context;
var context = usage.Context;
foreach (var beatmap in beatmapSet.Beatmaps.Where(b => b.Metadata != null))
{
@ -48,6 +48,7 @@ public void Add(BeatmapSetInfo beatmapSet)
}
context.BeatmapSetInfo.Attach(beatmapSet);
BeatmapSetAdded?.Invoke(beatmapSet);
}
}
@ -73,11 +74,12 @@ public void Update(BeatmapSetInfo beatmapSet)
/// <returns>Whether the beatmap's <see cref="BeatmapSetInfo.DeletePending"/> was changed.</returns>
public bool Delete(BeatmapSetInfo beatmapSet)
{
using ( ContextFactory.GetForWrite())
using (ContextFactory.GetForWrite())
{
Refresh(ref beatmapSet, BeatmapSets);
if (beatmapSet.DeletePending) return false;
beatmapSet.DeletePending = true;
}
@ -92,11 +94,12 @@ public bool Delete(BeatmapSetInfo beatmapSet)
/// <returns>Whether the beatmap's <see cref="BeatmapSetInfo.DeletePending"/> was changed.</returns>
public bool Undelete(BeatmapSetInfo beatmapSet)
{
using ( ContextFactory.GetForWrite())
using (ContextFactory.GetForWrite())
{
Refresh(ref beatmapSet, BeatmapSets);
if (!beatmapSet.DeletePending) return false;
beatmapSet.DeletePending = false;
}
@ -116,6 +119,7 @@ public bool Hide(BeatmapInfo beatmap)
Refresh(ref beatmap, Beatmaps);
if (beatmap.Hidden) return false;
beatmap.Hidden = true;
BeatmapHidden?.Invoke(beatmap);
@ -136,6 +140,7 @@ public bool Restore(BeatmapInfo beatmap)
Refresh(ref beatmap, Beatmaps);
if (!beatmap.Hidden) return false;
beatmap.Hidden = false;
}
@ -155,7 +160,9 @@ public void Cleanup(Expression<Func<BeatmapSetInfo, bool>> query)
.Where(query)
.Include(s => s.Beatmaps).ThenInclude(b => b.Metadata)
.Include(s => s.Beatmaps).ThenInclude(b => b.BaseDifficulty)
.Include(s => s.Metadata);
.Include(s => s.Metadata).ToList();
if (!purgeable.Any()) return;
// metadata is M-N so we can't rely on cascades
context.BeatmapMetadata.RemoveRange(purgeable.Select(s => s.Metadata));

View File

@ -34,10 +34,7 @@ protected virtual void Refresh<T>(ref T obj, IEnumerable<T> lookupSource = null)
var id = obj.ID;
var foundObject = lookupSource?.SingleOrDefault(t => t.ID == id) ?? context.Find<T>(id);
if (foundObject != null)
{
obj = foundObject;
context.Entry(obj).Reload();
}
else
context.Add(obj);
}

View File

@ -1,6 +1,7 @@
// Copyright (c) 2007-2018 ppy Pty Ltd <contact@ppy.sh>.
// Licensed under the MIT Licence - https://raw.githubusercontent.com/ppy/osu/master/LICENCE
using System.Diagnostics;
using System.Threading;
using osu.Framework.Platform;
@ -18,6 +19,7 @@ public class DatabaseContextFactory
private OsuDbContext writeContext;
private bool currentWriteDidWrite;
private volatile int currentWriteUsages;
public DatabaseContextFactory(GameHost host)
@ -38,24 +40,41 @@ public DatabaseContextFactory(GameHost host)
/// <returns>A usage containing a usable context.</returns>
public DatabaseWriteUsage GetForWrite()
{
lock (writeLock)
{
var usage = new DatabaseWriteUsage(writeContext ?? (writeContext = threadContexts.Value), usageCompleted);
Interlocked.Increment(ref currentWriteUsages);
return usage;
}
Monitor.Enter(writeLock);
Trace.Assert(currentWriteUsages == 0, "Database writes in a bad state");
Interlocked.Increment(ref currentWriteUsages);
return new DatabaseWriteUsage(writeContext ?? (writeContext = threadContexts.Value), usageCompleted);
}
private void usageCompleted(DatabaseWriteUsage usage)
{
int usages = Interlocked.Decrement(ref currentWriteUsages);
if (usages == 0)
try
{
writeContext.Dispose();
currentWriteDidWrite |= usage.PerformedWrite;
if (usages > 0) return;
if (currentWriteDidWrite)
{
writeContext.Dispose();
currentWriteDidWrite = false;
// once all writes are complete, we want to refresh thread-specific contexts to make sure they don't have stale local caches.
recycleThreadContexts();
}
// always set to null (even when a write didn't occur) so we get the correct thread context on next write request.
writeContext = null;
// once all writes are complete, we want to refresh thread-specific contexts to make sure they don't have stale local caches.
recycleThreadContexts();
}
finally
{
Monitor.Exit(writeLock);
}
}

View File

@ -19,10 +19,28 @@ public DatabaseWriteUsage(OsuDbContext context, Action<DatabaseWriteUsage> onCom
usageCompleted = onCompleted;
}
public bool PerformedWrite { get; private set; }
private bool isDisposed;
protected void Dispose(bool disposing)
{
if (isDisposed) return;
isDisposed = true;
PerformedWrite |= Context.SaveChanges(transaction) > 0;
usageCompleted?.Invoke(this);
}
public void Dispose()
{
Context.SaveChanges(transaction);
usageCompleted?.Invoke(this);
Dispose(true);
GC.SuppressFinalize(this);
}
~DatabaseWriteUsage()
{
Dispose(false);
}
}
}

View File

@ -111,7 +111,7 @@ public IDbContextTransaction BeginTransaction()
public int SaveChanges(IDbContextTransaction transaction = null)
{
var ret = base.SaveChanges();
transaction?.Commit();
if (ret > 0) transaction?.Commit();
return ret;
}

View File

@ -30,11 +30,9 @@ public FileInfo Add(Stream data, bool reference = true)
{
using (var usage = ContextFactory.GetForWrite())
{
var context = usage.Context;
string hash = data.ComputeSHA2Hash();
var existing = context.FileInfo.FirstOrDefault(f => f.Hash == hash);
var existing = usage.Context.FileInfo.FirstOrDefault(f => f.Hash == hash);
var info = existing ?? new FileInfo { Hash = hash };
@ -60,6 +58,8 @@ public FileInfo Add(Stream data, bool reference = true)
public void Reference(params FileInfo[] files)
{
if (files.Length == 0) return;
using (var usage = ContextFactory.GetForWrite())
{
var context = usage.Context;
@ -75,9 +75,12 @@ public void Reference(params FileInfo[] files)
public void Dereference(params FileInfo[] files)
{
if (files.Length == 0) return;
using (var usage = ContextFactory.GetForWrite())
{
var context = usage.Context;
foreach (var f in files.GroupBy(f => f.ID))
{
var refetch = context.FileInfo.Find(f.Key);

View File

@ -36,8 +36,6 @@ private void insertDefaults(IEnumerable<KeyBinding> defaults, int? rulesetId = n
{
using (var usage = ContextFactory.GetForWrite())
{
var context = usage.Context;
// compare counts in database vs defaults
foreach (var group in defaults.GroupBy(k => k.Action))
{
@ -49,7 +47,7 @@ private void insertDefaults(IEnumerable<KeyBinding> defaults, int? rulesetId = n
foreach (var insertable in group.Skip(count).Take(aimCount - count))
// insert any defaults which are missing.
context.DatabasedKeyBinding.Add(new DatabasedKeyBinding
usage.Context.DatabasedKeyBinding.Add(new DatabasedKeyBinding
{
KeyCombination = insertable.KeyCombination,
Action = insertable.Action,
@ -75,6 +73,10 @@ public void Update(KeyBinding keyBinding)
{
var dbKeyBinding = (DatabasedKeyBinding)keyBinding;
Refresh(ref dbKeyBinding);
if (dbKeyBinding.KeyCombination.Equals(keyBinding.KeyCombination))
return;
dbKeyBinding.KeyCombination = keyBinding.KeyCombination;
}