diff --git a/osu.Game/Database/RealmContextFactory.cs b/osu.Game/Database/RealmContextFactory.cs index 169333bff8..62717eb880 100644 --- a/osu.Game/Database/RealmContextFactory.cs +++ b/osu.Game/Database/RealmContextFactory.cs @@ -2,6 +2,8 @@ // See the LICENCE file in the repository root for full licence text. using System; +using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Linq; using System.Reflection; @@ -80,6 +82,10 @@ public Realm Context { context = createContext(); Logger.Log(@$"Opened realm ""{context.Config.DatabasePath}"" at version {context.Config.SchemaVersion}"); + + // Resubscribe any subscriptions + foreach (var action in subscriptionActions.Keys) + registerSubscription(action); } // creating a context will ensure our schema is up-to-date and migrated. @@ -226,26 +232,42 @@ public void Write(Action action) } } + private readonly Dictionary, IDisposable?> subscriptionActions = new Dictionary, IDisposable?>(); + /// /// Run work on realm that will be run every time the update thread realm context gets recycled. /// - /// The work to run. - public void Register(Action action) + /// The work to run. Return value should be an from QueryAsyncWithNotifications, or an to clean up any bindings. + /// An which should be disposed to unsubscribe any inner subscription. + public IDisposable Register(Func action) { - if (!ThreadSafety.IsUpdateThread && context != null) - throw new InvalidOperationException(@$"{nameof(BlockAllOperations)} must be called from the update thread."); + if (!ThreadSafety.IsUpdateThread) + throw new InvalidOperationException(@$"{nameof(Register)} must be called from the update thread."); - if (ThreadSafety.IsUpdateThread) + subscriptionActions.Add(action, null); + registerSubscription(action); + + return new InvokeOnDisposal(() => { - current_thread_subscriptions_allowed.Value = true; - action(Context); - current_thread_subscriptions_allowed.Value = false; - } - else + // TODO: this likely needs to be run on the update thread. + if (subscriptionActions.TryGetValue(action, out var unsubscriptionAction)) + { + unsubscriptionAction?.Dispose(); + subscriptionActions.Remove(action); + } + }); + } + + private void registerSubscription(Func action) + { + Debug.Assert(ThreadSafety.IsUpdateThread); + + lock (contextLock) { + Debug.Assert(context != null); + current_thread_subscriptions_allowed.Value = true; - using (var realm = createContext()) - action(realm); + subscriptionActions[action] = action(context); current_thread_subscriptions_allowed.Value = false; } }