✨ Enrich the user connections
This commit is contained in:
		| @@ -1,15 +0,0 @@ | ||||
| using DysonNetwork.Sphere.Account; | ||||
| using Microsoft.AspNetCore.Mvc; | ||||
| using Microsoft.EntityFrameworkCore; | ||||
| using NodaTime; | ||||
|  | ||||
| namespace DysonNetwork.Sphere.Auth.OpenId; | ||||
|  | ||||
| /// <summary> | ||||
| /// This controller is designed to handle the OAuth callback. | ||||
| /// </summary> | ||||
| [ApiController] | ||||
| [Route("/auth/callback")] | ||||
| public class AuthCallbackController : ControllerBase | ||||
| { | ||||
| } | ||||
| @@ -10,11 +10,10 @@ namespace DysonNetwork.Sphere.Auth.OpenId; | ||||
| [Route("/api/accounts/me/connections")] | ||||
| [Authorize] | ||||
| public class ConnectionController( | ||||
|     AppDatabase db,  | ||||
|     IEnumerable<OidcService> oidcServices,  | ||||
|     AccountService accountService,  | ||||
|     AuthService authService, | ||||
|     IClock clock | ||||
|     AppDatabase db, | ||||
|     IEnumerable<OidcService> oidcServices, | ||||
|     AccountService accounts, | ||||
|     AuthService auth | ||||
| ) : ControllerBase | ||||
| { | ||||
|     [HttpGet] | ||||
| @@ -25,7 +24,7 @@ public class ConnectionController( | ||||
|  | ||||
|         var connections = await db.AccountConnections | ||||
|             .Where(c => c.AccountId == currentUser.Id) | ||||
|             .Select(c => new { c.Id, c.AccountId, c.Provider, c.ProvidedIdentifier }) | ||||
|             .Select(c => new { c.Id, c.AccountId, c.Provider, c.ProvidedIdentifier, c.Meta, c.LastUsedAt }) | ||||
|             .ToListAsync(); | ||||
|         return Ok(connections); | ||||
|     } | ||||
| @@ -63,7 +62,8 @@ public class ConnectionController( | ||||
|         if (HttpContext.Items["CurrentUser"] is not Account.Account currentUser) | ||||
|             return Unauthorized(); | ||||
|  | ||||
|         var oidcService = oidcServices.FirstOrDefault(s => s.ProviderName.Equals(request.Provider, StringComparison.OrdinalIgnoreCase)); | ||||
|         var oidcService = oidcServices.FirstOrDefault(s => | ||||
|             s.ProviderName.Equals(request.Provider, StringComparison.OrdinalIgnoreCase)); | ||||
|         if (oidcService == null) | ||||
|             return BadRequest($"Provider '{request.Provider}' is not supported"); | ||||
|  | ||||
| @@ -94,7 +94,8 @@ public class ConnectionController( | ||||
|     [HttpGet, HttpPost] | ||||
|     public async Task<IActionResult> HandleCallback([FromRoute] string provider) | ||||
|     { | ||||
|         var oidcService = oidcServices.FirstOrDefault(s => s.ProviderName.Equals(provider, StringComparison.OrdinalIgnoreCase)); | ||||
|         var oidcService = | ||||
|             oidcServices.FirstOrDefault(s => s.ProviderName.Equals(provider, StringComparison.OrdinalIgnoreCase)); | ||||
|         if (oidcService == null) | ||||
|             return BadRequest($"Provider '{provider}' is not supported."); | ||||
|  | ||||
| @@ -106,21 +107,19 @@ public class ConnectionController( | ||||
|         HttpContext.Session.Remove($"oidc_state_{callbackData.State}"); | ||||
|  | ||||
|         // If sessionState is present, it's a manual connection flow for an existing user. | ||||
|         if (sessionState != null) | ||||
|         { | ||||
|             var stateParts = sessionState.Split('|'); | ||||
|             if (stateParts.Length != 3 || !stateParts[1].Equals(provider, StringComparison.OrdinalIgnoreCase)) | ||||
|                 return BadRequest("State mismatch."); | ||||
|         if (sessionState == null) return await HandleLoginOrRegistration(provider, oidcService, callbackData); | ||||
|         var stateParts = sessionState.Split('|'); | ||||
|         if (stateParts.Length != 3 || !stateParts[1].Equals(provider, StringComparison.OrdinalIgnoreCase)) | ||||
|             return BadRequest("State mismatch."); | ||||
|  | ||||
|         var accountId = Guid.Parse(stateParts[0]); | ||||
|         return await HandleManualConnection(provider, oidcService, callbackData, accountId); | ||||
|  | ||||
|             var accountId = Guid.Parse(stateParts[0]); | ||||
|             return await HandleManualConnection(provider, oidcService, callbackData, accountId); | ||||
|         } | ||||
|          | ||||
|         // Otherwise, it's a login or registration flow. | ||||
|         return await HandleLoginOrRegistration(provider, oidcService, callbackData); | ||||
|     } | ||||
|  | ||||
|     private async Task<IActionResult> HandleManualConnection(string provider, OidcService oidcService, OidcCallbackData callbackData, Guid accountId) | ||||
|     private async Task<IActionResult> HandleManualConnection(string provider, OidcService oidcService, | ||||
|         OidcCallbackData callbackData, Guid accountId) | ||||
|     { | ||||
|         OidcUserInfo userInfo; | ||||
|         try | ||||
| @@ -133,7 +132,9 @@ public class ConnectionController( | ||||
|         } | ||||
|  | ||||
|         var existingConnection = await db.AccountConnections | ||||
|             .FirstOrDefaultAsync(c => c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase) && c.ProvidedIdentifier == userInfo.UserId); | ||||
|             .FirstOrDefaultAsync(c => | ||||
|                 c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase) && | ||||
|                 c.ProvidedIdentifier == userInfo.UserId); | ||||
|  | ||||
|         if (existingConnection != null && existingConnection.AccountId != accountId) | ||||
|         { | ||||
| @@ -141,8 +142,10 @@ public class ConnectionController( | ||||
|         } | ||||
|  | ||||
|         var userConnection = await db.AccountConnections | ||||
|             .FirstOrDefaultAsync(c => c.AccountId == accountId && c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase)); | ||||
|             .FirstOrDefaultAsync(c => | ||||
|                 c.AccountId == accountId && c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase)); | ||||
|  | ||||
|         var clock = SystemClock.Instance; | ||||
|         if (userConnection != null) | ||||
|         { | ||||
|             userConnection.AccessToken = userInfo.AccessToken; | ||||
| @@ -158,7 +161,8 @@ public class ConnectionController( | ||||
|                 ProvidedIdentifier = userInfo.UserId!, | ||||
|                 AccessToken = userInfo.AccessToken, | ||||
|                 RefreshToken = userInfo.RefreshToken, | ||||
|                 LastUsedAt = clock.GetCurrentInstant() | ||||
|                 LastUsedAt = clock.GetCurrentInstant(), | ||||
|                 Meta = userInfo.ToMetadata(), | ||||
|             }); | ||||
|         } | ||||
|  | ||||
| @@ -170,7 +174,11 @@ public class ConnectionController( | ||||
|         return Redirect(string.IsNullOrEmpty(returnUrl) ? "/" : returnUrl); | ||||
|     } | ||||
|  | ||||
|     private async Task<IActionResult> HandleLoginOrRegistration(string provider, OidcService oidcService, OidcCallbackData callbackData) | ||||
|     private async Task<IActionResult> HandleLoginOrRegistration( | ||||
|         string provider, | ||||
|         OidcService oidcService, | ||||
|         OidcCallbackData callbackData | ||||
|     ) | ||||
|     { | ||||
|         OidcUserInfo userInfo; | ||||
|         try | ||||
| @@ -191,25 +199,17 @@ public class ConnectionController( | ||||
|             .Include(c => c.Account) | ||||
|             .FirstOrDefaultAsync(c => c.Provider == provider && c.ProvidedIdentifier == userInfo.UserId); | ||||
|  | ||||
|         var clock = SystemClock.Instance; | ||||
|         if (connection != null) | ||||
|         { | ||||
|             // Login existing user | ||||
|             var session = await authService.CreateSessionAsync(connection.Account, clock.GetCurrentInstant()); | ||||
|             var token = authService.CreateToken(session); | ||||
|             var session = await auth.CreateSessionAsync(connection.Account, clock.GetCurrentInstant()); | ||||
|             var token = auth.CreateToken(session); | ||||
|             return Redirect($"/?token={token}"); | ||||
|         } | ||||
|  | ||||
|         var account = await accountService.LookupAccount(userInfo.Email); | ||||
|         if (account == null) | ||||
|         { | ||||
|             // Register new user | ||||
|             account = await accountService.CreateAccount(userInfo); | ||||
|         } | ||||
|  | ||||
|         if (account == null) | ||||
|         { | ||||
|             return BadRequest("Unable to create or link account."); | ||||
|         } | ||||
|         // Register new user | ||||
|         var account = await accounts.LookupAccount(userInfo.Email) ?? await accounts.CreateAccount(userInfo); | ||||
|  | ||||
|         // Create connection for new or existing user | ||||
|         var newConnection = new AccountConnection | ||||
| @@ -219,40 +219,43 @@ public class ConnectionController( | ||||
|             ProvidedIdentifier = userInfo.UserId!, | ||||
|             AccessToken = userInfo.AccessToken, | ||||
|             RefreshToken = userInfo.RefreshToken, | ||||
|             LastUsedAt = clock.GetCurrentInstant() | ||||
|             LastUsedAt = clock.GetCurrentInstant(), | ||||
|             Meta = userInfo.ToMetadata() | ||||
|         }; | ||||
|         db.AccountConnections.Add(newConnection); | ||||
|  | ||||
|         await db.SaveChangesAsync(); | ||||
|  | ||||
|         var loginSession = await authService.CreateSessionAsync(account, clock.GetCurrentInstant()); | ||||
|         var loginToken = authService.CreateToken(loginSession); | ||||
|         var loginSession = await auth.CreateSessionAsync(account, clock.GetCurrentInstant()); | ||||
|         var loginToken = auth.CreateToken(loginSession); | ||||
|         return Redirect($"/?token={loginToken}"); | ||||
|     } | ||||
|  | ||||
|     private async Task<OidcCallbackData> ExtractCallbackData(HttpRequest request) | ||||
|     private static async Task<OidcCallbackData> ExtractCallbackData(HttpRequest request) | ||||
|     { | ||||
|         var data = new OidcCallbackData(); | ||||
|         if (request.Method == "GET") | ||||
|         switch (request.Method) | ||||
|         { | ||||
|             data.Code = request.Query["code"].FirstOrDefault() ?? ""; | ||||
|             data.IdToken = request.Query["id_token"].FirstOrDefault() ?? ""; | ||||
|             data.State = request.Query["state"].FirstOrDefault(); | ||||
|         } | ||||
|         else if (request.Method == "POST" && request.HasFormContentType) | ||||
|         { | ||||
|             var form = await request.ReadFormAsync(); | ||||
|             data.Code = form["code"].FirstOrDefault() ?? ""; | ||||
|             data.IdToken = form["id_token"].FirstOrDefault() ?? ""; | ||||
|             data.State = form["state"].FirstOrDefault(); | ||||
|             if (form.ContainsKey("user")) | ||||
|             case "GET": | ||||
|                 data.Code = request.Query["code"].FirstOrDefault() ?? ""; | ||||
|                 data.IdToken = request.Query["id_token"].FirstOrDefault() ?? ""; | ||||
|                 data.State = request.Query["state"].FirstOrDefault(); | ||||
|                 break; | ||||
|             case "POST" when request.HasFormContentType: | ||||
|             { | ||||
|                 data.RawData = form["user"].FirstOrDefault(); | ||||
|                 var form = await request.ReadFormAsync(); | ||||
|                 data.Code = form["code"].FirstOrDefault() ?? ""; | ||||
|                 data.IdToken = form["id_token"].FirstOrDefault() ?? ""; | ||||
|                 data.State = form["state"].FirstOrDefault(); | ||||
|                 if (form.ContainsKey("user")) | ||||
|                 { | ||||
|                     data.RawData = form["user"].FirstOrDefault(); | ||||
|                 } | ||||
|  | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         return data; | ||||
|     } | ||||
|  | ||||
|  | ||||
| } | ||||
| @@ -105,7 +105,7 @@ public class OidcController( | ||||
|         if (string.IsNullOrEmpty(userInfo.Email)) | ||||
|             throw new ArgumentException("Email is required for account creation"); | ||||
|  | ||||
|         // Check if account exists by email | ||||
|         // Check if an account exists by email | ||||
|         var existingAccount = await accounts.LookupAccount(userInfo.Email); | ||||
|         if (existingAccount != null) | ||||
|         { | ||||
| @@ -116,12 +116,28 @@ public class OidcController( | ||||
|                                           c.ProvidedIdentifier == userInfo.UserId); | ||||
|  | ||||
|             // If no connection exists, create one | ||||
|             if (existingConnection != null) return existingAccount; | ||||
|             if (existingConnection != null) | ||||
|             { | ||||
|                 await db.AccountConnections | ||||
|                     .Where(c => c.AccountId == existingAccount.Id && | ||||
|                                 c.Provider == provider && | ||||
|                                 c.ProvidedIdentifier == userInfo.UserId) | ||||
|                     .ExecuteUpdateAsync(s => s | ||||
|                         .SetProperty(c => c.LastUsedAt, SystemClock.Instance.GetCurrentInstant()) | ||||
|                         .SetProperty(c => c.Meta, userInfo.ToMetadata())); | ||||
|  | ||||
|                 return existingAccount; | ||||
|             } | ||||
|  | ||||
|             var connection = new AccountConnection | ||||
|             { | ||||
|                 AccountId = existingAccount.Id, | ||||
|                 Provider = provider, | ||||
|                 ProvidedIdentifier = userInfo.UserId!, | ||||
|                 AccessToken = userInfo.AccessToken, | ||||
|                 RefreshToken = userInfo.RefreshToken, | ||||
|                 LastUsedAt = SystemClock.Instance.GetCurrentInstant(), | ||||
|                 Meta = userInfo.ToMetadata() | ||||
|             }; | ||||
|  | ||||
|             db.AccountConnections.Add(connection); | ||||
| @@ -139,6 +155,10 @@ public class OidcController( | ||||
|             AccountId = newAccount.Id, | ||||
|             Provider = provider, | ||||
|             ProvidedIdentifier = userInfo.UserId!, | ||||
|             AccessToken = userInfo.AccessToken, | ||||
|             RefreshToken = userInfo.RefreshToken, | ||||
|             LastUsedAt = SystemClock.Instance.GetCurrentInstant(), | ||||
|             Meta = userInfo.ToMetadata() | ||||
|         }; | ||||
|  | ||||
|         db.AccountConnections.Add(newConnection); | ||||
|   | ||||
| @@ -16,4 +16,34 @@ public class OidcUserInfo | ||||
|     public string Provider { get; set; } = ""; | ||||
|     public string? RefreshToken { get; set; } | ||||
|     public string? AccessToken { get; set; } | ||||
|  | ||||
|     public Dictionary<string, object> ToMetadata() | ||||
|     { | ||||
|         var metadata = new Dictionary<string, object>(); | ||||
|  | ||||
|         if (!string.IsNullOrWhiteSpace(UserId)) | ||||
|             metadata["user_id"] = UserId; | ||||
|  | ||||
|         if (!string.IsNullOrWhiteSpace(Email)) | ||||
|             metadata["email"] = Email; | ||||
|  | ||||
|         metadata["email_verified"] = EmailVerified; | ||||
|  | ||||
|         if (!string.IsNullOrWhiteSpace(FirstName)) | ||||
|             metadata["first_name"] = FirstName; | ||||
|  | ||||
|         if (!string.IsNullOrWhiteSpace(LastName)) | ||||
|             metadata["last_name"] = LastName; | ||||
|  | ||||
|         if (!string.IsNullOrWhiteSpace(DisplayName)) | ||||
|             metadata["display_name"] = DisplayName; | ||||
|  | ||||
|         if (!string.IsNullOrWhiteSpace(PreferredUsername)) | ||||
|             metadata["preferred_username"] = PreferredUsername; | ||||
|  | ||||
|         if (!string.IsNullOrWhiteSpace(ProfilePictureUrl)) | ||||
|             metadata["profile_picture_url"] = ProfilePictureUrl; | ||||
|  | ||||
|         return metadata; | ||||
|     } | ||||
| } | ||||
		Reference in New Issue
	
	Block a user