Standardise context retrieval

This commit is contained in:
Dean Herbert 2017-10-17 15:50:42 +09:00
parent cd41862e3b
commit e487b6f82a
5 changed files with 101 additions and 66 deletions

View File

@ -29,12 +29,14 @@ namespace osu.Game.Beatmaps
{
if (reset)
{
var context = GetContext();
// https://stackoverflow.com/a/10450893
Context.Database.ExecuteSqlCommand("DELETE FROM BeatmapMetadata");
Context.Database.ExecuteSqlCommand("DELETE FROM BeatmapDifficulty");
Context.Database.ExecuteSqlCommand("DELETE FROM BeatmapSetInfo");
Context.Database.ExecuteSqlCommand("DELETE FROM BeatmapSetFileInfo");
Context.Database.ExecuteSqlCommand("DELETE FROM BeatmapInfo");
context.Database.ExecuteSqlCommand("DELETE FROM BeatmapMetadata");
context.Database.ExecuteSqlCommand("DELETE FROM BeatmapDifficulty");
context.Database.ExecuteSqlCommand("DELETE FROM BeatmapSetInfo");
context.Database.ExecuteSqlCommand("DELETE FROM BeatmapSetFileInfo");
context.Database.ExecuteSqlCommand("DELETE FROM BeatmapInfo");
}
}
@ -50,8 +52,10 @@ namespace osu.Game.Beatmaps
/// <param name="beatmapSet">The beatmap to add.</param>
public void Add(BeatmapSetInfo beatmapSet)
{
Context.BeatmapSetInfo.Attach(beatmapSet);
Context.SaveChanges();
var context = GetContext();
context.BeatmapSetInfo.Attach(beatmapSet);
context.SaveChanges();
BeatmapSetAdded?.Invoke(beatmapSet);
}
@ -63,11 +67,13 @@ namespace osu.Game.Beatmaps
/// <returns>Whether the beatmap's <see cref="BeatmapSetInfo.DeletePending"/> was changed.</returns>
public bool Delete(BeatmapSetInfo beatmapSet)
{
var context = GetContext();
if (beatmapSet.DeletePending) return false;
beatmapSet.DeletePending = true;
Context.BeatmapSetInfo.Update(beatmapSet);
Context.SaveChanges();
context.BeatmapSetInfo.Update(beatmapSet);
context.SaveChanges();
BeatmapSetRemoved?.Invoke(beatmapSet);
return true;
@ -80,11 +86,13 @@ namespace osu.Game.Beatmaps
/// <returns>Whether the beatmap's <see cref="BeatmapSetInfo.DeletePending"/> was changed.</returns>
public bool Undelete(BeatmapSetInfo beatmapSet)
{
var context = GetContext();
if (!beatmapSet.DeletePending) return false;
beatmapSet.DeletePending = false;
Context.BeatmapSetInfo.Update(beatmapSet);
Context.SaveChanges();
context.BeatmapSetInfo.Update(beatmapSet);
context.SaveChanges();
BeatmapSetAdded?.Invoke(beatmapSet);
return true;
@ -97,11 +105,13 @@ namespace osu.Game.Beatmaps
/// <returns>Whether the beatmap's <see cref="BeatmapInfo.Hidden"/> was changed.</returns>
public bool Hide(BeatmapInfo beatmap)
{
var context = GetContext();
if (beatmap.Hidden) return false;
beatmap.Hidden = true;
Context.BeatmapInfo.Update(beatmap);
Context.SaveChanges();
context.BeatmapInfo.Update(beatmap);
context.SaveChanges();
BeatmapHidden?.Invoke(beatmap);
return true;
@ -114,11 +124,13 @@ namespace osu.Game.Beatmaps
/// <returns>Whether the beatmap's <see cref="BeatmapInfo.Hidden"/> was changed.</returns>
public bool Restore(BeatmapInfo beatmap)
{
var context = GetContext();
if (!beatmap.Hidden) return false;
beatmap.Hidden = false;
Context.BeatmapInfo.Update(beatmap);
Context.SaveChanges();
context.BeatmapInfo.Update(beatmap);
context.SaveChanges();
BeatmapRestored?.Invoke(beatmap);
return true;
@ -126,21 +138,23 @@ namespace osu.Game.Beatmaps
private void cleanupPendingDeletions()
{
Context.BeatmapSetInfo.RemoveRange(Context.BeatmapSetInfo.Where(b => b.DeletePending && !b.Protected));
Context.SaveChanges();
var context = GetContext();
context.BeatmapSetInfo.RemoveRange(context.BeatmapSetInfo.Where(b => b.DeletePending && !b.Protected));
context.SaveChanges();
}
public IEnumerable<BeatmapSetInfo> BeatmapSets => Context.BeatmapSetInfo
.Include(s => s.Metadata)
.Include(s => s.Beatmaps).ThenInclude(s => s.Ruleset)
.Include(s => s.Beatmaps).ThenInclude(b => b.Difficulty)
.Include(s => s.Beatmaps).ThenInclude(b => b.Metadata)
.Include(s => s.Files).ThenInclude(f => f.FileInfo);
public IEnumerable<BeatmapSetInfo> BeatmapSets => GetContext().BeatmapSetInfo
.Include(s => s.Metadata)
.Include(s => s.Beatmaps).ThenInclude(s => s.Ruleset)
.Include(s => s.Beatmaps).ThenInclude(b => b.Difficulty)
.Include(s => s.Beatmaps).ThenInclude(b => b.Metadata)
.Include(s => s.Files).ThenInclude(f => f.FileInfo);
public IEnumerable<BeatmapInfo> Beatmaps => Context.BeatmapInfo
.Include(b => b.BeatmapSet).ThenInclude(s => s.Metadata)
.Include(b => b.Metadata)
.Include(b => b.Ruleset)
.Include(b => b.Difficulty);
public IEnumerable<BeatmapInfo> Beatmaps => GetContext().BeatmapInfo
.Include(b => b.BeatmapSet).ThenInclude(s => s.Metadata)
.Include(b => b.Metadata)
.Include(b => b.Ruleset)
.Include(b => b.Difficulty);
}
}

View File

@ -11,14 +11,12 @@ namespace osu.Game.Database
{
protected readonly Storage Storage;
private readonly Func<OsuDbContext> contextSource;
protected readonly Func<OsuDbContext> GetContext;
protected OsuDbContext Context => contextSource();
protected DatabaseBackedStore(Func<OsuDbContext> contextSource, Storage storage = null)
protected DatabaseBackedStore(Func<OsuDbContext> getContext, Storage storage = null)
{
Storage = storage;
this.contextSource = contextSource;
GetContext = getContext;
try
{

View File

@ -22,7 +22,7 @@ namespace osu.Game.IO
public readonly ResourceStore<byte[]> Store;
public FileStore(Func<OsuDbContext> contextSource, Storage storage) : base(contextSource, storage)
public FileStore(Func<OsuDbContext> getContext, Storage storage) : base(getContext, storage)
{
Store = new NamespacedResourceStore<byte[]>(new StorageBackedResourceStore(storage), prefix);
}
@ -34,7 +34,7 @@ namespace osu.Game.IO
if (Storage.ExistsDirectory(prefix))
Storage.DeleteDirectory(prefix);
Context.Database.ExecuteSqlCommand("DELETE FROM FileInfo");
GetContext().Database.ExecuteSqlCommand("DELETE FROM FileInfo");
}
}
@ -46,9 +46,11 @@ namespace osu.Game.IO
public FileInfo Add(Stream data, bool reference = true)
{
var context = GetContext();
string hash = data.ComputeSHA2Hash();
var existing = Context.FileInfo.FirstOrDefault(f => f.Hash == hash);
var existing = context.FileInfo.FirstOrDefault(f => f.Hash == hash);
var info = existing ?? new FileInfo { Hash = hash };
@ -71,38 +73,44 @@ namespace osu.Game.IO
return info;
}
public void Reference(params FileInfo[] files)
public void Reference(params FileInfo[] files) => reference(GetContext(), files);
private void reference(OsuDbContext context, FileInfo[] files)
{
foreach (var f in files.GroupBy(f => f.ID))
{
var refetch = Context.Find<FileInfo>(f.First().ID) ?? f.First();
var refetch = context.Find<FileInfo>(f.First().ID) ?? f.First();
refetch.ReferenceCount += f.Count();
Context.FileInfo.Update(refetch);
context.FileInfo.Update(refetch);
}
Context.SaveChanges();
context.SaveChanges();
}
public void Dereference(params FileInfo[] files)
public void Dereference(params FileInfo[] files) => dereference(GetContext(), files);
private void dereference(OsuDbContext context, FileInfo[] files)
{
foreach (var f in files.GroupBy(f => f.ID))
{
var refetch = Context.Find<FileInfo>(f.First().ID);
var refetch = context.Find<FileInfo>(f.First().ID);
refetch.ReferenceCount -= f.Count();
Context.Update(refetch);
context.Update(refetch);
}
Context.SaveChanges();
context.SaveChanges();
}
private void deletePending()
{
foreach (var f in Context.FileInfo.Where(f => f.ReferenceCount < 1))
var context = GetContext();
foreach (var f in context.FileInfo.Where(f => f.ReferenceCount < 1))
{
try
{
Storage.Delete(Path.Combine(prefix, f.StoragePath));
Context.FileInfo.Remove(f);
context.FileInfo.Remove(f);
}
catch (Exception e)
{
@ -110,7 +118,7 @@ namespace osu.Game.IO
}
}
Context.SaveChanges();
context.SaveChanges();
}
}
}

View File

@ -15,9 +15,16 @@ namespace osu.Game.Input
{
public class KeyBindingStore : DatabaseBackedStore
{
public KeyBindingStore(Func<OsuDbContext> contextSource, RulesetStore rulesets, Storage storage = null)
: base(contextSource, storage)
/// <summary>
/// As we do a lot of lookups, let's share a context between them to hopefully improve performance.
/// </summary>
private readonly OsuDbContext queryContext;
public KeyBindingStore(Func<OsuDbContext> getContext, RulesetStore rulesets, Storage storage = null)
: base(getContext, storage)
{
queryContext = GetContext();
foreach (var info in rulesets.AvailableRulesets)
{
var ruleset = info.CreateInstance();
@ -31,15 +38,17 @@ namespace osu.Game.Input
protected override void Prepare(bool reset = false)
{
if (reset)
Context.Database.ExecuteSqlCommand("DELETE FROM KeyBinding");
GetContext().Database.ExecuteSqlCommand("DELETE FROM KeyBinding");
}
private void insertDefaults(IEnumerable<KeyBinding> defaults, int? rulesetId = null, int? variant = null)
{
var context = GetContext();
// compare counts in database vs defaults
foreach (var group in defaults.GroupBy(k => k.Action))
{
int count = Query(rulesetId, variant).Count(k => (int)k.Action == (int)group.Key);
int count = query(context, rulesetId, variant).Count(k => (int)k.Action == (int)group.Key);
int aimCount = group.Count();
if (aimCount <= count)
@ -47,7 +56,7 @@ namespace osu.Game.Input
foreach (var insertable in group.Skip(count).Take(aimCount - count))
// insert any defaults which are missing.
Context.DatabasedKeyBinding.Add(new DatabasedKeyBinding
context.DatabasedKeyBinding.Add(new DatabasedKeyBinding
{
KeyCombination = insertable.KeyCombination,
Action = insertable.Action,
@ -56,7 +65,7 @@ namespace osu.Game.Input
});
}
Context.SaveChanges();
context.SaveChanges();
}
/// <summary>
@ -65,12 +74,16 @@ namespace osu.Game.Input
/// <param name="rulesetId">The ruleset's internal ID.</param>
/// <param name="variant">An optional variant.</param>
/// <returns></returns>
public IEnumerable<KeyBinding> Query(int? rulesetId = null, int? variant = null) => Context.DatabasedKeyBinding.Where(b => b.RulesetID == rulesetId && b.Variant == variant);
public IEnumerable<KeyBinding> Query(int? rulesetId = null, int? variant = null) => query(queryContext, rulesetId, variant);
private IEnumerable<KeyBinding> query(OsuDbContext context, int? rulesetId = null, int? variant = null) =>
context.DatabasedKeyBinding.Where(b => b.RulesetID == rulesetId && b.Variant == variant);
public void Update(KeyBinding keyBinding)
{
Context.Update(keyBinding);
Context.SaveChanges();
var context = GetContext();
context.Update(keyBinding);
context.SaveChanges();
}
}
}

View File

@ -41,7 +41,7 @@ namespace osu.Game.Rulesets
/// <summary>
/// All available rulesets.
/// </summary>
public IEnumerable<RulesetInfo> AvailableRulesets => Context.RulesetInfo.Where(r => r.Available);
public IEnumerable<RulesetInfo> AvailableRulesets => GetContext().RulesetInfo.Where(r => r.Available);
private static Assembly currentDomain_AssemblyResolve(object sender, ResolveEventArgs args) => loaded_assemblies.Keys.FirstOrDefault(a => a.FullName == args.Name);
@ -49,9 +49,11 @@ namespace osu.Game.Rulesets
protected override void Prepare(bool reset = false)
{
var context = GetContext();
if (reset)
{
Context.Database.ExecuteSqlCommand("DELETE FROM RulesetInfo");
context.Database.ExecuteSqlCommand("DELETE FROM RulesetInfo");
}
var instances = loaded_assemblies.Values.Select(r => (Ruleset)Activator.CreateInstance(r, new RulesetInfo())).ToList();
@ -60,29 +62,29 @@ namespace osu.Game.Rulesets
foreach (var r in instances.Where(r => r.LegacyID >= 0).OrderBy(r => r.LegacyID))
{
var rulesetInfo = createRulesetInfo(r);
if (Context.RulesetInfo.SingleOrDefault(rsi => rsi.ID == rulesetInfo.ID) == null)
if (context.RulesetInfo.SingleOrDefault(rsi => rsi.ID == rulesetInfo.ID) == null)
{
Context.RulesetInfo.Add(rulesetInfo);
context.RulesetInfo.Add(rulesetInfo);
}
}
Context.SaveChanges();
context.SaveChanges();
//add any other modes
foreach (var r in instances.Where(r => r.LegacyID < 0))
{
var us = createRulesetInfo(r);
var existing = Context.RulesetInfo.FirstOrDefault(ri => ri.InstantiationInfo == us.InstantiationInfo);
var existing = context.RulesetInfo.FirstOrDefault(ri => ri.InstantiationInfo == us.InstantiationInfo);
if (existing == null)
Context.RulesetInfo.Add(us);
context.RulesetInfo.Add(us);
}
Context.SaveChanges();
context.SaveChanges();
//perform a consistency check
foreach (var r in Context.RulesetInfo)
foreach (var r in context.RulesetInfo)
{
try
{
@ -95,7 +97,7 @@ namespace osu.Game.Rulesets
}
}
Context.SaveChanges();
context.SaveChanges();
}
private static void loadRulesetFromFile(string file)