From 33bb1212d1e84c3d887380d27f6159ca95361ad5 Mon Sep 17 00:00:00 2001
From: Dan Balasescu <smoogipoo@smgi.me>
Date: Fri, 28 Oct 2022 16:19:15 +0900
Subject: [PATCH] Add notifications websocket + chat implementation

---
 osu.Game/Online/Chat/Message.cs               |   7 ++
 .../Online/Notifications/EndChatRequest.cs    |  16 +++
 .../Notifications/NewChatMessageData.cs       |  29 +++++
 .../Notifications/NotificationsClient.cs      | 113 ++++++++++++++++++
 .../NotificationsClientConnector.cs           |  65 ++++++++++
 .../NotificationsClient_Processing.cs         | 102 ++++++++++++++++
 .../Online/Notifications/SocketMessage.cs     |  21 ++++
 .../Online/Notifications/StartChatRequest.cs  |  16 +++
 8 files changed, 369 insertions(+)
 create mode 100644 osu.Game/Online/Notifications/EndChatRequest.cs
 create mode 100644 osu.Game/Online/Notifications/NewChatMessageData.cs
 create mode 100644 osu.Game/Online/Notifications/NotificationsClient.cs
 create mode 100644 osu.Game/Online/Notifications/NotificationsClientConnector.cs
 create mode 100644 osu.Game/Online/Notifications/NotificationsClient_Processing.cs
 create mode 100644 osu.Game/Online/Notifications/SocketMessage.cs
 create mode 100644 osu.Game/Online/Notifications/StartChatRequest.cs

diff --git a/osu.Game/Online/Chat/Message.cs b/osu.Game/Online/Chat/Message.cs
index 86562341eb..25c5b0853f 100644
--- a/osu.Game/Online/Chat/Message.cs
+++ b/osu.Game/Online/Chat/Message.cs
@@ -30,6 +30,13 @@ namespace osu.Game.Online.Chat
         [JsonProperty(@"sender")]
         public APIUser Sender;
 
+        [JsonProperty(@"sender_id")]
+        public int SenderId
+        {
+            get => Sender?.Id ?? 0;
+            set => Sender = new APIUser { Id = value };
+        }
+
         [JsonConstructor]
         public Message()
         {
diff --git a/osu.Game/Online/Notifications/EndChatRequest.cs b/osu.Game/Online/Notifications/EndChatRequest.cs
new file mode 100644
index 0000000000..1173b1e8d0
--- /dev/null
+++ b/osu.Game/Online/Notifications/EndChatRequest.cs
@@ -0,0 +1,16 @@
+// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
+// See the LICENCE file in the repository root for full licence text.
+
+using Newtonsoft.Json;
+
+namespace osu.Game.Online.Notifications
+{
+    [JsonObject(MemberSerialization.OptIn)]
+    public class EndChatRequest : SocketMessage
+    {
+        public EndChatRequest()
+        {
+            Event = "chat.end";
+        }
+    }
+}
diff --git a/osu.Game/Online/Notifications/NewChatMessageData.cs b/osu.Game/Online/Notifications/NewChatMessageData.cs
new file mode 100644
index 0000000000..b388afa743
--- /dev/null
+++ b/osu.Game/Online/Notifications/NewChatMessageData.cs
@@ -0,0 +1,29 @@
+// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
+// See the LICENCE file in the repository root for full licence text.
+
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.Serialization;
+using Newtonsoft.Json;
+using osu.Game.Online.API.Requests.Responses;
+using osu.Game.Online.Chat;
+
+namespace osu.Game.Online.Notifications
+{
+    [JsonObject(MemberSerialization.OptIn)]
+    public class NewChatMessageData
+    {
+        [JsonProperty("messages")]
+        public List<Message> Messages { get; set; } = null!;
+
+        [JsonProperty("users")]
+        private List<APIUser> users { get; set; } = null!;
+
+        [OnDeserialized]
+        private void onDeserialised(StreamingContext context)
+        {
+            foreach (var m in Messages)
+                m.Sender = users.Single(u => u.OnlineID == m.SenderId);
+        }
+    }
+}
diff --git a/osu.Game/Online/Notifications/NotificationsClient.cs b/osu.Game/Online/Notifications/NotificationsClient.cs
new file mode 100644
index 0000000000..63260e5df9
--- /dev/null
+++ b/osu.Game/Online/Notifications/NotificationsClient.cs
@@ -0,0 +1,113 @@
+// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
+// See the LICENCE file in the repository root for full licence text.
+
+using System;
+using System.Diagnostics;
+using System.Net.WebSockets;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using Newtonsoft.Json;
+using osu.Framework.Extensions.TypeExtensions;
+using osu.Framework.Logging;
+using osu.Game.Online.API;
+
+namespace osu.Game.Online.Notifications
+{
+    public partial class NotificationsClient : SocketClient
+    {
+        private readonly ClientWebSocket socket;
+        private readonly string endpoint;
+        private readonly IAPIProvider api;
+
+        public NotificationsClient(ClientWebSocket socket, string endpoint, IAPIProvider api)
+        {
+            this.socket = socket;
+            this.endpoint = endpoint;
+            this.api = api;
+        }
+
+        public override async Task StartAsync(CancellationToken cancellationToken)
+        {
+            await socket.ConnectAsync(new Uri(endpoint), cancellationToken).ConfigureAwait(false);
+            await onConnectedAsync();
+            runReadLoop(cancellationToken);
+        }
+
+        private void runReadLoop(CancellationToken cancellationToken) => Task.Run((Func<Task>)(async () =>
+        {
+            byte[] buffer = new byte[1024];
+            StringBuilder messageResult = new StringBuilder();
+
+            while (!cancellationToken.IsCancellationRequested)
+            {
+                try
+                {
+                    WebSocketReceiveResult result = await socket.ReceiveAsync(buffer, cancellationToken);
+
+                    switch (result.MessageType)
+                    {
+                        case WebSocketMessageType.Text:
+                            messageResult.Append(Encoding.UTF8.GetString(buffer[..result.Count]));
+
+                            if (result.EndOfMessage)
+                            {
+                                SocketMessage? message = JsonConvert.DeserializeObject<SocketMessage>(messageResult.ToString());
+                                messageResult.Clear();
+
+                                Debug.Assert(message != null);
+
+                                if (message.Error != null)
+                                {
+                                    Logger.Log($"{GetType().ReadableName()} error: {message.Error}", LoggingTarget.Network);
+                                    break;
+                                }
+
+                                await onMessageReceivedAsync(message);
+                            }
+
+                            break;
+
+                        case WebSocketMessageType.Binary:
+                            throw new NotImplementedException();
+
+                        case WebSocketMessageType.Close:
+                            throw new Exception("Connection closed by remote host.");
+                    }
+                }
+                catch (Exception ex)
+                {
+                    await InvokeClosed(ex);
+                    return;
+                }
+            }
+        }), cancellationToken);
+
+        private async Task closeAsync()
+        {
+            try
+            {
+                await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Disconnecting", CancellationToken.None).ConfigureAwait(false);
+            }
+            catch
+            {
+                // Closure can fail if the connection is aborted. Don't really care since it's disposed anyway.
+            }
+        }
+
+        private async Task sendMessage(SocketMessage message, CancellationToken cancellationToken)
+        {
+            if (socket.State != WebSocketState.Open)
+                return;
+
+            await socket.SendAsync(Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(message)), WebSocketMessageType.Text, true, cancellationToken);
+        }
+
+        public override async ValueTask DisposeAsync()
+        {
+            await base.DisposeAsync();
+            await closeAsync();
+            socket.Dispose();
+        }
+    }
+}
diff --git a/osu.Game/Online/Notifications/NotificationsClientConnector.cs b/osu.Game/Online/Notifications/NotificationsClientConnector.cs
new file mode 100644
index 0000000000..18b2a1b19d
--- /dev/null
+++ b/osu.Game/Online/Notifications/NotificationsClientConnector.cs
@@ -0,0 +1,65 @@
+// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
+// See the LICENCE file in the repository root for full licence text.
+
+using System;
+using System.Collections.Generic;
+using System.Net;
+using System.Net.WebSockets;
+using System.Threading;
+using System.Threading.Tasks;
+using osu.Game.Online.API;
+using osu.Game.Online.API.Requests;
+using osu.Game.Online.Chat;
+
+namespace osu.Game.Online.Notifications
+{
+    public class NotificationsClientConnector : SocketClientConnector
+    {
+        public event Action<Channel>? ChannelJoined;
+        public event Action<List<Message>>? NewMessages;
+        public event Action? PresenceReceived;
+
+        private readonly IAPIProvider api;
+        private bool chatStarted;
+
+        public NotificationsClientConnector(IAPIProvider api)
+            : base(api)
+        {
+            this.api = api;
+        }
+
+        public void StartChat()
+        {
+            chatStarted = true;
+
+            if (CurrentConnection is NotificationsClient client)
+                client.EnableChat = true;
+        }
+
+        protected override async Task<SocketClient> BuildConnectionAsync(CancellationToken cancellationToken)
+        {
+            var tcs = new TaskCompletionSource<string>();
+
+            var req = new GetNotificationsRequest();
+            req.Success += bundle => tcs.SetResult(bundle.Endpoint);
+            req.Failure += ex => tcs.SetException(ex);
+            api.Queue(req);
+
+            string endpoint = await tcs.Task;
+
+            ClientWebSocket socket = new ClientWebSocket();
+            socket.Options.SetRequestHeader("Authorization", $"Bearer {api.AccessToken}");
+            socket.Options.Proxy = WebRequest.DefaultWebProxy;
+            if (socket.Options.Proxy != null)
+                socket.Options.Proxy.Credentials = CredentialCache.DefaultCredentials;
+
+            return new NotificationsClient(socket, endpoint, api)
+            {
+                ChannelJoined = c => ChannelJoined?.Invoke(c),
+                NewMessages = m => NewMessages?.Invoke(m),
+                PresenceReceived = () => PresenceReceived?.Invoke(),
+                EnableChat = chatStarted
+            };
+        }
+    }
+}
diff --git a/osu.Game/Online/Notifications/NotificationsClient_Processing.cs b/osu.Game/Online/Notifications/NotificationsClient_Processing.cs
new file mode 100644
index 0000000000..4950a53f6f
--- /dev/null
+++ b/osu.Game/Online/Notifications/NotificationsClient_Processing.cs
@@ -0,0 +1,102 @@
+// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
+// See the LICENCE file in the repository root for full licence text.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Newtonsoft.Json;
+using osu.Game.Online.API.Requests;
+using osu.Game.Online.Chat;
+
+namespace osu.Game.Online.Notifications
+{
+    public partial class NotificationsClient
+    {
+        public Action<Channel>? ChannelJoined;
+        public Action<List<Message>>? NewMessages;
+        public Action? PresenceReceived;
+
+        private bool enableChat;
+        private long lastMessageId;
+
+        public bool EnableChat
+        {
+            get => enableChat;
+            set
+            {
+                enableChat = value;
+                Task.Run(startChatIfEnabledAsync);
+            }
+        }
+
+        private async Task onConnectedAsync()
+        {
+            await startChatIfEnabledAsync();
+        }
+
+        private async Task startChatIfEnabledAsync()
+        {
+            if (!EnableChat)
+                return;
+
+            await sendMessage(new StartChatRequest(), CancellationToken.None);
+
+            var fetchReq = new GetUpdatesRequest(lastMessageId);
+
+            fetchReq.Success += updates =>
+            {
+                if (updates?.Presence != null)
+                {
+                    foreach (var channel in updates.Presence)
+                        handleJoinedChannel(channel);
+
+                    //todo: handle left channels
+
+                    handleMessages(updates.Messages);
+                }
+
+                PresenceReceived?.Invoke();
+            };
+
+            api.Queue(fetchReq);
+        }
+
+        private Task onMessageReceivedAsync(SocketMessage message)
+        {
+            switch (message.Event)
+            {
+                case "chat.message.new":
+                    Debug.Assert(message.Data != null);
+
+                    NewChatMessageData? messageData = JsonConvert.DeserializeObject<NewChatMessageData>(message.Data.ToString());
+                    Debug.Assert(messageData != null);
+
+                    List<Message> messages = messageData.Messages.Where(m => m.Sender.OnlineID != api.LocalUser.Value.OnlineID).ToList();
+
+                    foreach (var msg in messages)
+                        handleJoinedChannel(new Channel(msg.Sender) { Id = msg.ChannelId });
+
+                    handleMessages(messages);
+                    break;
+            }
+
+            return Task.CompletedTask;
+        }
+
+        private void handleJoinedChannel(Channel channel)
+        {
+            // we received this from the server so should mark the channel already joined.
+            channel.Joined.Value = true;
+            ChannelJoined?.Invoke(channel);
+        }
+
+        private void handleMessages(List<Message> messages)
+        {
+            NewMessages?.Invoke(messages);
+            lastMessageId = messages.LastOrDefault()?.Id ?? lastMessageId;
+        }
+    }
+}
diff --git a/osu.Game/Online/Notifications/SocketMessage.cs b/osu.Game/Online/Notifications/SocketMessage.cs
new file mode 100644
index 0000000000..6b5f3435fc
--- /dev/null
+++ b/osu.Game/Online/Notifications/SocketMessage.cs
@@ -0,0 +1,21 @@
+// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
+// See the LICENCE file in the repository root for full licence text.
+
+using Newtonsoft.Json;
+using Newtonsoft.Json.Linq;
+
+namespace osu.Game.Online.Notifications
+{
+    [JsonObject(MemberSerialization.OptIn)]
+    public class SocketMessage
+    {
+        [JsonProperty("event")]
+        public string Event { get; set; } = null!;
+
+        [JsonProperty("data")]
+        public JObject? Data { get; set; }
+
+        [JsonProperty("error")]
+        public string? Error { get; set; }
+    }
+}
diff --git a/osu.Game/Online/Notifications/StartChatRequest.cs b/osu.Game/Online/Notifications/StartChatRequest.cs
new file mode 100644
index 0000000000..274738886d
--- /dev/null
+++ b/osu.Game/Online/Notifications/StartChatRequest.cs
@@ -0,0 +1,16 @@
+// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
+// See the LICENCE file in the repository root for full licence text.
+
+using Newtonsoft.Json;
+
+namespace osu.Game.Online.Notifications
+{
+    [JsonObject(MemberSerialization.OptIn)]
+    public class StartChatRequest : SocketMessage
+    {
+        public StartChatRequest()
+        {
+            Event = "chat.start";
+        }
+    }
+}