diff --git a/DysonNetwork.Sphere/Auth/OpenId/ConnectionController.cs b/DysonNetwork.Sphere/Auth/OpenId/ConnectionController.cs index 0799ba3..07c319c 100644 --- a/DysonNetwork.Sphere/Auth/OpenId/ConnectionController.cs +++ b/DysonNetwork.Sphere/Auth/OpenId/ConnectionController.cs @@ -195,9 +195,15 @@ public class ConnectionController( return await HandleManualConnection(provider, oidcService, callbackData, accountId); } - private async Task HandleManualConnection(string provider, OidcService oidcService, - OidcCallbackData callbackData, Guid accountId) + private async Task HandleManualConnection( + string provider, + OidcService oidcService, + OidcCallbackData callbackData, + Guid accountId + ) { + provider = provider.ToLower(); + OidcUserInfo userInfo; try { @@ -216,7 +222,7 @@ public class ConnectionController( // Check if this provider account is already connected to any user var existingConnection = await db.AccountConnections .FirstOrDefaultAsync(c => - c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase) && + c.Provider == provider && c.ProvidedIdentifier == userInfo.UserId); // If it's connected to a different user, return error @@ -229,7 +235,7 @@ public class ConnectionController( var userHasProvider = await db.AccountConnections .AnyAsync(c => c.AccountId == accountId && - c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase)); + c.Provider == provider); if (userHasProvider) { @@ -237,7 +243,7 @@ public class ConnectionController( var connection = await db.AccountConnections .FirstOrDefaultAsync(c => c.AccountId == accountId && - c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase)); + c.Provider == provider); if (connection != null) { @@ -342,18 +348,18 @@ public class ConnectionController( switch (request.Method) { case "GET": - data.Code = request.Query["code"].FirstOrDefault() ?? ""; - data.IdToken = request.Query["id_token"].FirstOrDefault() ?? ""; - data.State = request.Query["state"].FirstOrDefault(); + data.Code = Uri.UnescapeDataString(request.Query["code"].FirstOrDefault() ?? ""); + data.IdToken = Uri.UnescapeDataString(request.Query["id_token"].FirstOrDefault() ?? ""); + data.State = Uri.UnescapeDataString(request.Query["state"].FirstOrDefault() ?? ""); break; case "POST" when request.HasFormContentType: { var form = await request.ReadFormAsync(); - data.Code = form["code"].FirstOrDefault() ?? ""; - data.IdToken = form["id_token"].FirstOrDefault() ?? ""; - data.State = form["state"].FirstOrDefault(); + data.Code = Uri.UnescapeDataString(form["code"].FirstOrDefault() ?? ""); + data.IdToken = Uri.UnescapeDataString(form["id_token"].FirstOrDefault() ?? ""); + data.State = Uri.UnescapeDataString(form["state"].FirstOrDefault() ?? ""); if (form.ContainsKey("user")) - data.RawData = form["user"].FirstOrDefault(); + data.RawData = Uri.UnescapeDataString(form["user"].FirstOrDefault() ?? ""); break; }