♻️ Refactor the way to handle websocket

This commit is contained in:
2025-09-21 23:07:20 +08:00
parent e3657386cd
commit 204640a759
13 changed files with 196 additions and 219 deletions

View File

@@ -1,5 +1,6 @@
using System.Net.WebSockets;
using DysonNetwork.Shared.Proto;
using WebSocketPacket = DysonNetwork.Shared.Data.WebSocketPacket;
namespace DysonNetwork.Ring.Connection;

View File

@@ -3,11 +3,15 @@ using DysonNetwork.Shared.Proto;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
using Swashbuckle.AspNetCore.Annotations;
using WebSocketPacket = DysonNetwork.Shared.Data.WebSocketPacket;
namespace DysonNetwork.Ring.Connection;
[ApiController]
public class WebSocketController(WebSocketService ws, ILogger<WebSocketContext> logger) : ControllerBase
public class WebSocketController(
WebSocketService ws,
ILogger<WebSocketContext> logger
) : ControllerBase
{
[Route("/ws")]
[Authorize]
@@ -23,7 +27,7 @@ public class WebSocketController(WebSocketService ws, ILogger<WebSocketContext>
return;
}
var accountId = currentUser.Id!;
var accountId = Guid.Parse(currentUser.Id!);
var deviceId = currentSession.Challenge?.DeviceId ?? Guid.NewGuid().ToString();
if (string.IsNullOrEmpty(deviceId))
@@ -89,7 +93,7 @@ public class WebSocketController(WebSocketService ws, ILogger<WebSocketContext>
CancellationToken cancellationToken
)
{
var connectionKey = (AccountId: currentUser.Id, DeviceId: deviceId);
var connectionKey = (AccountId: Guid.Parse(currentUser.Id), DeviceId: deviceId);
var buffer = new byte[1024 * 4];
try

View File

@@ -1,97 +0,0 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using DysonNetwork.Shared.Proto;
using NodaTime;
using NodaTime.Serialization.SystemTextJson;
namespace DysonNetwork.Ring.Connection;
public class WebSocketPacket
{
public string Type { get; set; } = null!;
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public object? Data { get; set; } = null!;
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? Endpoint { get; set; }
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? ErrorMessage { get; set; }
/// <summary>
/// Creates a WebSocketPacket from raw WebSocket message bytes
/// </summary>
/// <param name="bytes">Raw WebSocket message bytes</param>
/// <returns>Deserialized WebSocketPacket</returns>
public static WebSocketPacket FromBytes(byte[] bytes)
{
var json = System.Text.Encoding.UTF8.GetString(bytes);
var jsonOpts = new JsonSerializerOptions
{
NumberHandling = JsonNumberHandling.AllowNamedFloatingPointLiterals,
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
DictionaryKeyPolicy = JsonNamingPolicy.SnakeCaseLower,
};
return JsonSerializer.Deserialize<WebSocketPacket>(json, jsonOpts) ??
throw new JsonException("Failed to deserialize WebSocketPacket");
}
/// <summary>
/// Deserializes the Data property to the specified type T
/// </summary>
/// <typeparam name="T">Target type to deserialize to</typeparam>
/// <returns>Deserialized data of type T</returns>
public T? GetData<T>()
{
if (Data is T typedData)
return typedData;
var jsonOpts = new JsonSerializerOptions
{
NumberHandling = JsonNumberHandling.AllowNamedFloatingPointLiterals,
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
DictionaryKeyPolicy = JsonNamingPolicy.SnakeCaseLower,
};
return JsonSerializer.Deserialize<T>(
JsonSerializer.Serialize(Data, jsonOpts),
jsonOpts
);
}
/// <summary>
/// Serializes this WebSocketPacket to a byte array for sending over WebSocket
/// </summary>
/// <returns>Byte array representation of the packet</returns>
public byte[] ToBytes()
{
var jsonOpts = new JsonSerializerOptions
{
NumberHandling = JsonNumberHandling.AllowNamedFloatingPointLiterals,
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
DictionaryKeyPolicy = JsonNamingPolicy.SnakeCaseLower,
}.ConfigureForNodaTime(DateTimeZoneProviders.Tzdb);
var json = JsonSerializer.Serialize(this, jsonOpts);
return System.Text.Encoding.UTF8.GetBytes(json);
}
public Shared.Proto.WebSocketPacket ToProtoValue()
{
return new Shared.Proto.WebSocketPacket
{
Type = Type,
Data = GrpcTypeHelper.ConvertObjectToByteString(Data),
ErrorMessage = ErrorMessage
};
}
public static WebSocketPacket FromProtoValue(Shared.Proto.WebSocketPacket packet)
{
return new WebSocketPacket
{
Type = packet.Type,
Data = GrpcTypeHelper.ConvertByteStringToObject<object?>(packet.Data),
ErrorMessage = packet.ErrorMessage
};
}
}

View File

@@ -2,36 +2,38 @@ using System.Collections.Concurrent;
using System.Net.WebSockets;
using DysonNetwork.Shared.Data;
using DysonNetwork.Shared.Proto;
using Grpc.Core;
using DysonNetwork.Shared.Stream;
using NATS.Client.Core;
using WebSocketPacket = DysonNetwork.Shared.Data.WebSocketPacket;
namespace DysonNetwork.Ring.Connection;
public class WebSocketService
{
private readonly IConfiguration _configuration;
private readonly INatsConnection _nats;
private readonly ILogger<WebSocketService> _logger;
private readonly IDictionary<string, IWebSocketPacketHandler> _handlerMap;
public WebSocketService(
IEnumerable<IWebSocketPacketHandler> handlers,
ILogger<WebSocketService> logger,
IConfiguration configuration
INatsConnection nats
)
{
_logger = logger;
_configuration = configuration;
_handlerMap = handlers.ToDictionary(h => h.PacketType);
_nats = nats;
}
private static readonly ConcurrentDictionary<
(string AccountId, string DeviceId),
(Guid AccountId, string DeviceId),
(WebSocket Socket, CancellationTokenSource Cts)
> ActiveConnections = new();
private static readonly ConcurrentDictionary<string, string> ActiveSubscriptions = new(); // deviceId -> chatRoomId
public bool TryAdd(
(string AccountId, string DeviceId) key,
(Guid AccountId, string DeviceId) key,
WebSocket socket,
CancellationTokenSource cts
)
@@ -42,7 +44,7 @@ public class WebSocketService
return ActiveConnections.TryAdd(key, (socket, cts));
}
public void Disconnect((string AccountId, string DeviceId) key, string? reason = null)
public void Disconnect((Guid AccountId, string DeviceId) key, string? reason = null)
{
if (!ActiveConnections.TryGetValue(key, out var data)) return;
try
@@ -63,19 +65,19 @@ public class WebSocketService
ActiveConnections.TryRemove(key, out _);
}
public bool GetDeviceIsConnected(string deviceId)
public static bool GetDeviceIsConnected(string deviceId)
{
return ActiveConnections.Any(c => c.Key.DeviceId == deviceId);
}
public bool GetAccountIsConnected(string accountId)
public static bool GetAccountIsConnected(Guid accountId)
{
return ActiveConnections.Any(c => c.Key.AccountId == accountId);
}
public void SendPacketToAccount(string userId, WebSocketPacket packet)
public static void SendPacketToAccount(Guid accountId, WebSocketPacket packet)
{
var connections = ActiveConnections.Where(c => c.Key.AccountId == userId);
var connections = ActiveConnections.Where(c => c.Key.AccountId == accountId);
var packetBytes = packet.ToBytes();
var segment = new ArraySegment<byte>(packetBytes);
@@ -139,28 +141,16 @@ public class WebSocketService
try
{
var endpoint = packet.Endpoint.Replace("DysonNetwork.", "").ToLower();
var serviceUrl = "https://_grpc." + endpoint;
var callInvoker = GrpcClientHelper.CreateCallInvoker(serviceUrl);
var client = new RingHandlerService.RingHandlerServiceClient(callInvoker);
try
await _nats.PublishAsync(WebSocketPacketEvent.SubjectPrefix + endpoint, new WebSocketPacketEvent
{
await client.ReceiveWebSocketPacketAsync(new ReceiveWebSocketPacketRequest
{
Account = currentUser,
DeviceId = deviceId,
Packet = packet.ToProtoValue()
});
}
catch (RpcException ex)
{
_logger.LogError(ex, $"Error forwarding packet to endpoint: {packet.Endpoint} (${endpoint})");
}
AccountId = Guid.Parse(currentUser.Id),
DeviceId = deviceId,
PacketBytes = packet.ToBytes()
});
}
catch (Exception ex)
{
_logger.LogError(ex, $"Error forwarding packet to endpoint: {packet.Endpoint}");
_logger.LogError(ex, "Error forwarding packet to endpoint: {Endpoint}", packet.Endpoint);
}
}
@@ -175,4 +165,4 @@ public class WebSocketService
CancellationToken.None
);
}
}
}