173 lines
		
	
	
		
			5.4 KiB
		
	
	
	
		
			C#
		
	
	
	
	
	
			
		
		
	
	
			173 lines
		
	
	
		
			5.4 KiB
		
	
	
	
		
			C#
		
	
	
	
	
	
using System.Collections.Concurrent;
 | 
						|
using System.Net.WebSockets;
 | 
						|
using DysonNetwork.Shared.Models;
 | 
						|
using DysonNetwork.Shared.Proto;
 | 
						|
using DysonNetwork.Shared.Stream;
 | 
						|
using NATS.Client.Core;
 | 
						|
using WebSocketPacket = DysonNetwork.Shared.Models.WebSocketPacket;
 | 
						|
 | 
						|
namespace DysonNetwork.Ring.Connection;
 | 
						|
 | 
						|
/// <summary>
/// Tracks live WebSocket connections keyed by (account id, device id), pushes server-initiated packets to
/// them, and routes inbound packets: ping/pong, registered <see cref="IWebSocketPacketHandler"/>s, or
/// forwarding over NATS when the packet names an endpoint.
/// </summary>
/// <remarks>
/// The connection registry is <c>static</c>, so it is shared by every instance of this service in the process.
/// </remarks>
public class WebSocketService
{
    private readonly INatsConnection _nats;
    private readonly ILogger<WebSocketService> _logger;

    // Packet type -> handler. Built with ToDictionary, so duplicate PacketType registrations throw at construction.
    private readonly IDictionary<string, IWebSocketPacketHandler> _handlerMap;

    public WebSocketService(
        IEnumerable<IWebSocketPacketHandler> handlers,
        ILogger<WebSocketService> logger,
        INatsConnection nats
    )
    {
        _logger = logger;
        _handlerMap = handlers.ToDictionary(h => h.PacketType);
        _nats = nats;
    }

    private static readonly ConcurrentDictionary<
        (Guid AccountId, string DeviceId),
        (WebSocket Socket, CancellationTokenSource Cts)
    > ActiveConnections = new();

    // NOTE(review): not referenced anywhere in this class; confirm whether it is still needed.
    private static readonly ConcurrentDictionary<string, string> ActiveSubscriptions = new(); // deviceId -> chatRoomId

    /// <summary>
    /// Registers a connection under <paramref name="key"/>. Any connection already registered under the same
    /// key is disconnected first, so at most one connection is kept per (account, device) pair.
    /// </summary>
    /// <returns>
    /// <c>true</c> if the connection was registered; <c>false</c> if another connection claimed the same key
    /// between the eviction and the insert.
    /// </returns>
    public bool TryAdd(
        (Guid AccountId, string DeviceId) key,
        WebSocket socket,
        CancellationTokenSource cts
    )
    {
        // Disconnect is a no-op when the key is absent, so no separate lookup is needed first.
        Disconnect(key, "Just connected somewhere else with the same identifier.");
        return ActiveConnections.TryAdd(key, (socket, cts));
    }

    /// <summary>
    /// Unregisters the connection stored under <paramref name="key"/>, starts a normal-closure handshake on
    /// its socket and cancels its <see cref="CancellationTokenSource"/>. Does nothing if no connection is
    /// registered under the key.
    /// </summary>
    /// <param name="key">Account id and device id of the connection.</param>
    /// <param name="reason">Close description sent to the peer; a default text is used when <c>null</c>.</param>
    public void Disconnect((Guid AccountId, string DeviceId) key, string? reason = null)
    {
        // Remove first: TryRemove is an atomic claim, so concurrent Disconnect calls cannot both close/cancel
        // the same connection, and a connection registered under this key afterwards is never evicted here.
        if (!ActiveConnections.TryRemove(key, out var data)) return;

        // Not awaited because this API is synchronous; the helper logs failures instead of leaving a faulted,
        // unobserved task. CloseAsync is still started before the token is cancelled, as before.
        _ = CloseSocketAsync(key, data.Socket, reason ?? "Server just decided to disconnect.");

        try
        {
            data.Cts.Cancel();
        }
        catch (ObjectDisposedException)
        {
            // The token source was already disposed by its owner; there is nothing left to cancel.
        }
    }

    /// <summary>
    /// Sends a normal-closure frame on <paramref name="socket"/> if it is still in a closable state,
    /// logging (never throwing) any failure.
    /// </summary>
    private async Task CloseSocketAsync(
        (Guid AccountId, string DeviceId) key,
        WebSocket socket,
        string description
    )
    {
        // CloseAsync is only valid while the handshake has not already been started/finished by us.
        if (socket.State is not (WebSocketState.Open or WebSocketState.CloseReceived)) return;

        try
        {
            await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, description, CancellationToken.None);
        }
        catch (Exception ex)
        {
            _logger.LogWarning(ex, "Error while closing WebSocket for {AccountId}:{DeviceId}", key.AccountId,
                key.DeviceId);
        }
    }

    /// <summary>Returns whether any registered connection has the given device id.</summary>
    public static bool GetDeviceIsConnected(string deviceId)
    {
        return ActiveConnections.Any(c => c.Key.DeviceId == deviceId);
    }

    /// <summary>Returns whether any registered connection belongs to the given account.</summary>
    public static bool GetAccountIsConnected(Guid accountId)
    {
        return ActiveConnections.Any(c => c.Key.AccountId == accountId);
    }

    /// <summary>
    /// Pushes <paramref name="packet"/> as a binary frame to every registered connection of the account.
    /// Delivery is best-effort and not awaited.
    /// </summary>
    public static void SendPacketToAccount(Guid accountId, WebSocketPacket packet)
    {
        var segment = new ArraySegment<byte>(packet.ToBytes());

        foreach (var connection in ActiveConnections.Where(c => c.Key.AccountId == accountId))
        {
            _ = SendBestEffortAsync(connection.Value.Socket, segment);
        }
    }

    /// <summary>
    /// Pushes <paramref name="packet"/> as a binary frame to every registered connection with the given
    /// device id. Delivery is best-effort and not awaited.
    /// </summary>
    public void SendPacketToDevice(string deviceId, WebSocketPacket packet)
    {
        var segment = new ArraySegment<byte>(packet.ToBytes());

        foreach (var connection in ActiveConnections.Where(c => c.Key.DeviceId == deviceId))
        {
            _ = SendBestEffortAsync(connection.Value.Socket, segment, _logger);
        }
    }

    /// <summary>
    /// Fire-and-forget binary send used for server-initiated pushes. Sockets that are not open are skipped,
    /// and transport failures caused by a peer going away are caught, so one dead connection cannot abort
    /// delivery to the others or surface as an unobserved task exception.
    /// </summary>
    /// <param name="logger">Optional logger; when <c>null</c> dropped packets are not logged.</param>
    /// <remarks>
    /// NOTE(review): pushes are not serialized against other sends on the same socket (e.g. handler replies);
    /// WebSocket allows only one outstanding send at a time, so an overlapping push is dropped here.
    /// </remarks>
    private static async Task SendBestEffortAsync(
        WebSocket socket,
        ArraySegment<byte> segment,
        ILogger<WebSocketService>? logger = null
    )
    {
        if (socket.State != WebSocketState.Open) return;

        try
        {
            await socket.SendAsync(segment, WebSocketMessageType.Binary, true, CancellationToken.None);
        }
        catch (Exception ex) when (ex is WebSocketException
                                       or ObjectDisposedException
                                       or InvalidOperationException
                                       or OperationCanceledException)
        {
            logger?.LogDebug(ex, "Dropped outgoing WebSocket packet");
        }
    }

    /// <summary>
    /// Processes one inbound packet received on <paramref name="socket"/>:
    /// a <c>Ping</c> is answered with <c>Pong</c>; otherwise the handler registered for the packet type is
    /// invoked; otherwise, if the packet names an endpoint, it is published to NATS. Anything left over
    /// (no handler and no endpoint, or a failed publish) is answered with an <c>Error</c> packet.
    /// </summary>
    public async Task HandlePacket(
        Account currentUser,
        string deviceId,
        WebSocketPacket packet,
        WebSocket socket
    )
    {
        if (packet.Type == WebSocketPacketType.Ping)
        {
            await socket.SendAsync(
                new ArraySegment<byte>(new WebSocketPacket
                {
                    Type = WebSocketPacketType.Pong
                }.ToBytes()),
                WebSocketMessageType.Binary,
                true,
                CancellationToken.None
            );
            return;
        }

        if (_handlerMap.TryGetValue(packet.Type, out var handler))
        {
            await handler.HandleAsync(currentUser, deviceId, packet, socket, this);
            return;
        }

        if (packet.Endpoint is not null)
        {
            try
            {
                // Strips every "DysonNetwork." occurrence and lowercases the rest to form the subject suffix.
                // Invariant lowering keeps the subject independent of the server culture (with ToLower(), a
                // tr-TR culture would map "I" to dotless "ı" and publish to the wrong subject).
                var endpoint = packet.Endpoint.Replace("DysonNetwork.", "").ToLowerInvariant();
                await _nats.PublishAsync(
                    WebSocketPacketEvent.SubjectPrefix + endpoint,
                    GrpcTypeHelper
                        .ConvertObjectToByteString(new WebSocketPacketEvent
                        {
                            AccountId = Guid.Parse(currentUser.Id),
                            DeviceId = deviceId,
                            PacketBytes = packet.ToBytes()
                        }).ToByteArray()
                );
                return;
            }
            catch (Exception ex)
            {
                // Fall through so the client still receives an error reply.
                _logger.LogError(ex, "Error forwarding packet to endpoint: {Endpoint}", packet.Endpoint);
            }
        }

        await socket.SendAsync(
            new ArraySegment<byte>(new WebSocketPacket
            {
                Type = WebSocketPacketType.Error,
                ErrorMessage = $"Unprocessable packet: {packet.Type}"
            }.ToBytes()),
            WebSocketMessageType.Binary,
            true,
            CancellationToken.None
        );
    }
}