✨ Manual setup account connections
🐛 Fix infinite oauth token reconnect websocket due to missing device id 🐛 Fix IP forwarded headers didn't work
This commit is contained in:
		| @@ -2,13 +2,20 @@ using DysonNetwork.Sphere.Account; | ||||
| using Microsoft.AspNetCore.Authorization; | ||||
| using Microsoft.AspNetCore.Mvc; | ||||
| using Microsoft.EntityFrameworkCore; | ||||
| using NodaTime; | ||||
|  | ||||
| namespace DysonNetwork.Sphere.Auth.OpenId; | ||||
|  | ||||
| [ApiController] | ||||
| [Route("/api/connections")] | ||||
| [Route("/api/accounts/me/connections")] | ||||
| [Authorize] | ||||
| public class ConnectionController(AppDatabase db) : ControllerBase | ||||
| public class ConnectionController( | ||||
|     AppDatabase db,  | ||||
|     IEnumerable<OidcService> oidcServices,  | ||||
|     AccountService accountService,  | ||||
|     AuthService authService, | ||||
|     IClock clock | ||||
| ) : ControllerBase | ||||
| { | ||||
|     [HttpGet] | ||||
|     public async Task<ActionResult<List<AccountConnection>>> GetConnections() | ||||
| @@ -40,4 +47,212 @@ public class ConnectionController(AppDatabase db) : ControllerBase | ||||
|  | ||||
|         return Ok(); | ||||
|     } | ||||
|  | ||||
|     public class ConnectProviderRequest | ||||
|     { | ||||
|         public string Provider { get; set; } = null!; | ||||
|         public string? ReturnUrl { get; set; } | ||||
|     } | ||||
|  | ||||
|     /// <summary> | ||||
|     /// Initiates manual connection to an OAuth provider for the current user | ||||
|     /// </summary> | ||||
|     [HttpPost("connect")] | ||||
|     public async Task<ActionResult<object>> InitiateConnection([FromBody] ConnectProviderRequest request) | ||||
|     { | ||||
|         if (HttpContext.Items["CurrentUser"] is not Account.Account currentUser) | ||||
|             return Unauthorized(); | ||||
|  | ||||
|         var oidcService = oidcServices.FirstOrDefault(s => s.ProviderName.Equals(request.Provider, StringComparison.OrdinalIgnoreCase)); | ||||
|         if (oidcService == null) | ||||
|             return BadRequest($"Provider '{request.Provider}' is not supported"); | ||||
|  | ||||
|         var existingConnection = await db.AccountConnections | ||||
|             .AnyAsync(c => c.AccountId == currentUser.Id && c.Provider == oidcService.ProviderName); | ||||
|  | ||||
|         if (existingConnection) | ||||
|             return BadRequest($"You already have a {request.Provider} connection"); | ||||
|  | ||||
|         var state = Guid.NewGuid().ToString("N"); | ||||
|         var nonce = Guid.NewGuid().ToString("N"); | ||||
|         HttpContext.Session.SetString($"oidc_state_{state}", $"{currentUser.Id}|{request.Provider}|{nonce}"); | ||||
|  | ||||
|         var finalReturnUrl = !string.IsNullOrEmpty(request.ReturnUrl) ? request.ReturnUrl : "/settings/connections"; | ||||
|         HttpContext.Session.SetString($"oidc_return_url_{state}", finalReturnUrl); | ||||
|  | ||||
|         var authUrl = oidcService.GetAuthorizationUrl(state, nonce); | ||||
|  | ||||
|         return Ok(new | ||||
|         { | ||||
|             authUrl, | ||||
|             message = $"Redirect to this URL to connect your {request.Provider} account" | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|     [AllowAnonymous] | ||||
|     [Route("/auth/callback/{provider}")] | ||||
|     [HttpGet, HttpPost] | ||||
|     public async Task<IActionResult> HandleCallback([FromRoute] string provider) | ||||
|     { | ||||
|         var oidcService = oidcServices.FirstOrDefault(s => s.ProviderName.Equals(provider, StringComparison.OrdinalIgnoreCase)); | ||||
|         if (oidcService == null) | ||||
|             return BadRequest($"Provider '{provider}' is not supported."); | ||||
|  | ||||
|         var callbackData = await ExtractCallbackData(Request); | ||||
|         if (callbackData.State == null) | ||||
|             return BadRequest("State parameter is missing."); | ||||
|  | ||||
|         var sessionState = HttpContext.Session.GetString($"oidc_state_{callbackData.State!}"); | ||||
|         HttpContext.Session.Remove($"oidc_state_{callbackData.State}"); | ||||
|  | ||||
|         // If sessionState is present, it's a manual connection flow for an existing user. | ||||
|         if (sessionState != null) | ||||
|         { | ||||
|             var stateParts = sessionState.Split('|'); | ||||
|             if (stateParts.Length != 3 || !stateParts[1].Equals(provider, StringComparison.OrdinalIgnoreCase)) | ||||
|                 return BadRequest("State mismatch."); | ||||
|  | ||||
|             var accountId = Guid.Parse(stateParts[0]); | ||||
|             return await HandleManualConnection(provider, oidcService, callbackData, accountId); | ||||
|         } | ||||
|          | ||||
|         // Otherwise, it's a login or registration flow. | ||||
|         return await HandleLoginOrRegistration(provider, oidcService, callbackData); | ||||
|     } | ||||
|  | ||||
|     private async Task<IActionResult> HandleManualConnection(string provider, OidcService oidcService, OidcCallbackData callbackData, Guid accountId) | ||||
|     { | ||||
|         OidcUserInfo userInfo; | ||||
|         try | ||||
|         { | ||||
|             userInfo = await oidcService.ProcessCallbackAsync(callbackData); | ||||
|         } | ||||
|         catch (Exception ex) | ||||
|         { | ||||
|             return BadRequest($"Error processing callback: {ex.Message}"); | ||||
|         } | ||||
|  | ||||
|         var existingConnection = await db.AccountConnections | ||||
|             .FirstOrDefaultAsync(c => c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase) && c.ProvidedIdentifier == userInfo.UserId); | ||||
|  | ||||
|         if (existingConnection != null && existingConnection.AccountId != accountId) | ||||
|         { | ||||
|             return BadRequest($"This {provider} account is already linked to another user."); | ||||
|         } | ||||
|  | ||||
|         var userConnection = await db.AccountConnections | ||||
|             .FirstOrDefaultAsync(c => c.AccountId == accountId && c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase)); | ||||
|  | ||||
|         if (userConnection != null) | ||||
|         { | ||||
|             userConnection.AccessToken = userInfo.AccessToken; | ||||
|             userConnection.RefreshToken = userInfo.RefreshToken; | ||||
|             userConnection.LastUsedAt = clock.GetCurrentInstant(); | ||||
|         } | ||||
|         else | ||||
|         { | ||||
|             db.AccountConnections.Add(new AccountConnection | ||||
|             { | ||||
|                 AccountId = accountId, | ||||
|                 Provider = provider, | ||||
|                 ProvidedIdentifier = userInfo.UserId!, | ||||
|                 AccessToken = userInfo.AccessToken, | ||||
|                 RefreshToken = userInfo.RefreshToken, | ||||
|                 LastUsedAt = clock.GetCurrentInstant() | ||||
|             }); | ||||
|         } | ||||
|  | ||||
|         await db.SaveChangesAsync(); | ||||
|  | ||||
|         var returnUrl = HttpContext.Session.GetString($"oidc_return_url_{callbackData.State}"); | ||||
|         HttpContext.Session.Remove($"oidc_return_url_{callbackData.State}"); | ||||
|  | ||||
|         return Redirect(string.IsNullOrEmpty(returnUrl) ? "/" : returnUrl); | ||||
|     } | ||||
|  | ||||
|     private async Task<IActionResult> HandleLoginOrRegistration(string provider, OidcService oidcService, OidcCallbackData callbackData) | ||||
|     { | ||||
|         OidcUserInfo userInfo; | ||||
|         try | ||||
|         { | ||||
|             userInfo = await oidcService.ProcessCallbackAsync(callbackData); | ||||
|         } | ||||
|         catch (Exception ex) | ||||
|         { | ||||
|             return BadRequest($"Error processing callback: {ex.Message}"); | ||||
|         } | ||||
|  | ||||
|         if (string.IsNullOrEmpty(userInfo.Email) || string.IsNullOrEmpty(userInfo.UserId)) | ||||
|         { | ||||
|             return BadRequest($"Email or user ID is missing from {provider}'s response"); | ||||
|         } | ||||
|  | ||||
|         var connection = await db.AccountConnections | ||||
|             .Include(c => c.Account) | ||||
|             .FirstOrDefaultAsync(c => c.Provider == provider && c.ProvidedIdentifier == userInfo.UserId); | ||||
|  | ||||
|         if (connection != null) | ||||
|         { | ||||
|             // Login existing user | ||||
|             var session = await authService.CreateSessionAsync(connection.Account, clock.GetCurrentInstant()); | ||||
|             var token = authService.CreateToken(session); | ||||
|             return Redirect($"/?token={token}"); | ||||
|         } | ||||
|  | ||||
|         var account = await accountService.LookupAccount(userInfo.Email); | ||||
|         if (account == null) | ||||
|         { | ||||
|             // Register new user | ||||
|             account = await accountService.CreateAccount(userInfo); | ||||
|         } | ||||
|  | ||||
|         if (account == null) | ||||
|         { | ||||
|             return BadRequest("Unable to create or link account."); | ||||
|         } | ||||
|  | ||||
|         // Create connection for new or existing user | ||||
|         var newConnection = new AccountConnection | ||||
|         { | ||||
|             Account = account, | ||||
|             Provider = provider, | ||||
|             ProvidedIdentifier = userInfo.UserId!, | ||||
|             AccessToken = userInfo.AccessToken, | ||||
|             RefreshToken = userInfo.RefreshToken, | ||||
|             LastUsedAt = clock.GetCurrentInstant() | ||||
|         }; | ||||
|         db.AccountConnections.Add(newConnection); | ||||
|  | ||||
|         await db.SaveChangesAsync(); | ||||
|  | ||||
|         var loginSession = await authService.CreateSessionAsync(account, clock.GetCurrentInstant()); | ||||
|         var loginToken = authService.CreateToken(loginSession); | ||||
|         return Redirect($"/?token={loginToken}"); | ||||
|     } | ||||
|  | ||||
|     private async Task<OidcCallbackData> ExtractCallbackData(HttpRequest request) | ||||
|     { | ||||
|         var data = new OidcCallbackData(); | ||||
|         if (request.Method == "GET") | ||||
|         { | ||||
|             data.Code = request.Query["code"].FirstOrDefault() ?? ""; | ||||
|             data.IdToken = request.Query["id_token"].FirstOrDefault() ?? ""; | ||||
|             data.State = request.Query["state"].FirstOrDefault(); | ||||
|         } | ||||
|         else if (request.Method == "POST" && request.HasFormContentType) | ||||
|         { | ||||
|             var form = await request.ReadFormAsync(); | ||||
|             data.Code = form["code"].FirstOrDefault() ?? ""; | ||||
|             data.IdToken = form["id_token"].FirstOrDefault() ?? ""; | ||||
|             data.State = form["state"].FirstOrDefault(); | ||||
|             if (form.ContainsKey("user")) | ||||
|             { | ||||
|                 data.RawData = form["user"].FirstOrDefault(); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         return data; | ||||
|     } | ||||
|  | ||||
|  | ||||
| } | ||||
		Reference in New Issue
	
	Block a user