♻️ Refactor the state and auth system
This commit is contained in:
		| @@ -15,9 +15,10 @@ public class AppleOidcService( | ||||
|     IConfiguration configuration, | ||||
|     IHttpClientFactory httpClientFactory, | ||||
|     AppDatabase db, | ||||
|     AuthService auth, | ||||
|     ICacheService cache | ||||
| ) | ||||
|     : OidcService(configuration, httpClientFactory, db, cache) | ||||
|     : OidcService(configuration, httpClientFactory, db, auth, cache) | ||||
| { | ||||
|     private readonly IConfiguration _configuration = configuration; | ||||
|     private readonly IHttpClientFactory _httpClientFactory = httpClientFactory; | ||||
|   | ||||
| @@ -176,23 +176,52 @@ public class ConnectionController( | ||||
|         if (callbackData.State == null) | ||||
|             return BadRequest("State parameter is missing."); | ||||
|  | ||||
|         // Get and validate state from cache | ||||
|         // Get the state from the cache | ||||
|         var stateKey = $"{StateCachePrefix}{callbackData.State}"; | ||||
|         var stateValue = await cache.GetAsync<string>(stateKey); | ||||
|         if (string.IsNullOrEmpty(stateValue)) | ||||
|             return BadRequest("Invalid or expired state parameter"); | ||||
|  | ||||
|         // Remove state from cache to prevent replay attacks | ||||
|          | ||||
|         // Try to get the state as OidcState first (new format) | ||||
|         var oidcState = await cache.GetAsync<OidcState>(stateKey); | ||||
|          | ||||
|         // If not found, try to get as string (legacy format) | ||||
|         if (oidcState == null) | ||||
|         { | ||||
|             var stateValue = await cache.GetAsync<string>(stateKey); | ||||
|             if (string.IsNullOrEmpty(stateValue) || !OidcState.TryParse(stateValue, out oidcState) || oidcState == null) | ||||
|                 return BadRequest("Invalid or expired state parameter"); | ||||
|         } | ||||
|          | ||||
|         // Remove the state from cache to prevent replay attacks | ||||
|         await cache.RemoveAsync(stateKey); | ||||
|  | ||||
|         var stateParts = stateValue.Split('|'); | ||||
|         if (stateParts.Length != 3) | ||||
|         // Handle the flow based on state type | ||||
|         if (oidcState.FlowType == OidcFlowType.Connect && oidcState.AccountId.HasValue) | ||||
|         { | ||||
|             return BadRequest("Invalid state format"); | ||||
|             // Connection flow | ||||
|             if (oidcState.DeviceId != null) | ||||
|             { | ||||
|                 callbackData.State = oidcState.DeviceId; | ||||
|             } | ||||
|             return await HandleManualConnection(provider, oidcService, callbackData, oidcState.AccountId.Value); | ||||
|         } | ||||
|         else if (oidcState.FlowType == OidcFlowType.Login) | ||||
|         { | ||||
|             // Login/Registration flow | ||||
|             if (!string.IsNullOrEmpty(oidcState.DeviceId)) | ||||
|             { | ||||
|                 callbackData.State = oidcState.DeviceId; | ||||
|             } | ||||
|  | ||||
|             // Store return URL if provided | ||||
|             if (!string.IsNullOrEmpty(oidcState.ReturnUrl) && oidcState.ReturnUrl != "/") | ||||
|             { | ||||
|                 var returnUrlKey = $"{ReturnUrlCachePrefix}{callbackData.State}"; | ||||
|                 await cache.SetAsync(returnUrlKey, oidcState.ReturnUrl, StateExpiration); | ||||
|             } | ||||
|  | ||||
|             return await HandleLoginOrRegistration(provider, oidcService, callbackData); | ||||
|         } | ||||
|  | ||||
|         var accountId = Guid.Parse(stateParts[0]); | ||||
|         return await HandleManualConnection(provider, oidcService, callbackData, accountId); | ||||
|         return BadRequest("Unsupported flow type"); | ||||
|     } | ||||
|  | ||||
|     private async Task<IActionResult> HandleManualConnection( | ||||
| @@ -219,6 +248,9 @@ public class ConnectionController( | ||||
|             return BadRequest($"{provider} did not return a valid user identifier."); | ||||
|         } | ||||
|  | ||||
|         // Extract device ID from the callback state if available | ||||
|         var deviceId = !string.IsNullOrEmpty(callbackData.State) ? callbackData.State : string.Empty; | ||||
|  | ||||
|         // Check if this provider account is already connected to any user | ||||
|         var existingConnection = await db.AccountConnections | ||||
|             .FirstOrDefaultAsync(c => | ||||
| @@ -314,9 +346,16 @@ public class ConnectionController( | ||||
|         if (connection != null) | ||||
|         { | ||||
|             // Login existing user | ||||
|             var session = await auth.CreateSessionAsync(connection.Account, clock.GetCurrentInstant()); | ||||
|             var token = auth.CreateToken(session); | ||||
|             return Redirect($"/auth/token?token={token}"); | ||||
|             var deviceId = !string.IsNullOrEmpty(callbackData.State) ?  | ||||
|                 callbackData.State.Split('|').FirstOrDefault() :  | ||||
|                 string.Empty; | ||||
|                  | ||||
|             var challenge = await oidcService.CreateChallengeForUserAsync( | ||||
|                 userInfo,  | ||||
|                 connection.Account,  | ||||
|                 HttpContext,  | ||||
|                 deviceId ?? string.Empty); | ||||
|             return Redirect($"/auth/callback?context={challenge.Id}"); | ||||
|         } | ||||
|  | ||||
|         // Register new user | ||||
|   | ||||
| @@ -8,9 +8,10 @@ public class DiscordOidcService( | ||||
|     IConfiguration configuration, | ||||
|     IHttpClientFactory httpClientFactory, | ||||
|     AppDatabase db, | ||||
|     AuthService auth, | ||||
|     ICacheService cache | ||||
| ) | ||||
|     : OidcService(configuration, httpClientFactory, db, cache) | ||||
|     : OidcService(configuration, httpClientFactory, db, auth, cache) | ||||
| { | ||||
|     public override string ProviderName => "Discord"; | ||||
|     protected override string DiscoveryEndpoint => ""; // Discord doesn't have a standard OIDC discovery endpoint | ||||
|   | ||||
| @@ -8,9 +8,10 @@ public class GitHubOidcService( | ||||
|     IConfiguration configuration, | ||||
|     IHttpClientFactory httpClientFactory, | ||||
|     AppDatabase db, | ||||
|     AuthService auth, | ||||
|     ICacheService cache | ||||
| ) | ||||
|     : OidcService(configuration, httpClientFactory, db, cache) | ||||
|     : OidcService(configuration, httpClientFactory, db, auth, cache) | ||||
| { | ||||
|     public override string ProviderName => "GitHub"; | ||||
|     protected override string DiscoveryEndpoint => ""; // GitHub doesn't have a standard OIDC discovery endpoint | ||||
|   | ||||
| @@ -11,9 +11,10 @@ public class GoogleOidcService( | ||||
|     IConfiguration configuration, | ||||
|     IHttpClientFactory httpClientFactory, | ||||
|     AppDatabase db, | ||||
|     AuthService auth, | ||||
|     ICacheService cache | ||||
| ) | ||||
|     : OidcService(configuration, httpClientFactory, db, cache) | ||||
|     : OidcService(configuration, httpClientFactory, db, auth, cache) | ||||
| { | ||||
|     private readonly IHttpClientFactory _httpClientFactory = httpClientFactory; | ||||
|  | ||||
| @@ -31,8 +32,6 @@ public class GoogleOidcService( | ||||
|             throw new InvalidOperationException("Authorization endpoint not found in discovery document"); | ||||
|         } | ||||
|  | ||||
|         // PKCE code removed: no code verifier/challenge generated | ||||
|  | ||||
|         var queryParams = new Dictionary<string, string> | ||||
|         { | ||||
|             { "client_id", config.ClientId }, | ||||
|   | ||||
| @@ -8,9 +8,10 @@ public class MicrosoftOidcService( | ||||
|     IConfiguration configuration, | ||||
|     IHttpClientFactory httpClientFactory, | ||||
|     AppDatabase db, | ||||
|     AuthService auth, | ||||
|     ICacheService cache | ||||
| ) | ||||
|     : OidcService(configuration, httpClientFactory, db, cache) | ||||
|     : OidcService(configuration, httpClientFactory, db, auth, cache) | ||||
| { | ||||
|     public override string ProviderName => "Microsoft"; | ||||
|  | ||||
|   | ||||
| @@ -13,7 +13,6 @@ public class OidcController( | ||||
|     IServiceProvider serviceProvider, | ||||
|     AppDatabase db, | ||||
|     AccountService accounts, | ||||
|     AuthService auth, | ||||
|     ICacheService cache | ||||
| ) | ||||
|     : ControllerBase | ||||
| @@ -22,7 +21,11 @@ public class OidcController( | ||||
|     private static readonly TimeSpan StateExpiration = TimeSpan.FromMinutes(15); | ||||
|  | ||||
|     [HttpGet("{provider}")] | ||||
|     public async Task<ActionResult> OidcLogin([FromRoute] string provider, [FromQuery] string? returnUrl = "/") | ||||
|     public async Task<ActionResult> OidcLogin( | ||||
|         [FromRoute] string provider, | ||||
|         [FromQuery] string? returnUrl = "/", | ||||
|         [FromHeader(Name = "X-Device-Id")] string? deviceId = null | ||||
|     ) | ||||
|     { | ||||
|         try | ||||
|         { | ||||
| @@ -34,9 +37,9 @@ public class OidcController( | ||||
|                 var state = Guid.NewGuid().ToString(); | ||||
|                 var nonce = Guid.NewGuid().ToString(); | ||||
|  | ||||
|                 // Store user's ID, provider, and nonce in cache. The callback will use this. | ||||
|                 var stateValue = $"{currentUser.Id}|{provider}|{nonce}"; | ||||
|                 await cache.SetAsync($"{StateCachePrefix}{state}", stateValue, StateExpiration); | ||||
|                 // Create and store connection state | ||||
|                 var oidcState = OidcState.ForConnection(currentUser.Id, provider, nonce, deviceId); | ||||
|                 await cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration); | ||||
|  | ||||
|                 // The state parameter sent to the provider is the GUID key for the cache. | ||||
|                 var authUrl = oidcService.GetAuthorizationUrl(state, nonce); | ||||
| @@ -45,9 +48,12 @@ public class OidcController( | ||||
|             else // Otherwise, proceed with the login / registration flow | ||||
|             { | ||||
|                 var nonce = Guid.NewGuid().ToString(); | ||||
|                 var state = Guid.NewGuid().ToString(); | ||||
|  | ||||
|                 // The state parameter is the returnUrl. The callback will not find a session state and will treat it as a login. | ||||
|                 var authUrl = oidcService.GetAuthorizationUrl(returnUrl ?? "/", nonce); | ||||
|                 // Create login state with return URL and device ID | ||||
|                 var oidcState = OidcState.ForLogin(returnUrl ?? "/", deviceId); | ||||
|                 await cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration); | ||||
|                 var authUrl = oidcService.GetAuthorizationUrl(state, nonce); | ||||
|                 return Redirect(authUrl); | ||||
|             } | ||||
|         } | ||||
| @@ -62,7 +68,7 @@ public class OidcController( | ||||
|     /// Handles Apple authentication directly from mobile apps | ||||
|     /// </summary> | ||||
|     [HttpPost("apple/mobile")] | ||||
|     public async Task<ActionResult<AuthController.TokenExchangeResponse>> AppleMobileSignIn( | ||||
|     public async Task<ActionResult<Challenge>> AppleMobileLogin( | ||||
|         [FromBody] AppleMobileSignInRequest request) | ||||
|     { | ||||
|         try | ||||
| @@ -85,17 +91,14 @@ public class OidcController( | ||||
|             var account = await FindOrCreateAccount(userInfo, "apple"); | ||||
|  | ||||
|             // Create session using the OIDC service | ||||
|             var session = await appleService.CreateSessionForUserAsync( | ||||
|             var challenge = await appleService.CreateChallengeForUserAsync( | ||||
|                 userInfo, | ||||
|                 account, | ||||
|                 HttpContext, | ||||
|                 request.DeviceId | ||||
|             ); | ||||
|  | ||||
|             // Generate token using existing auth service | ||||
|             var token = auth.CreateToken(session); | ||||
|  | ||||
|             return Ok(new AuthController.TokenExchangeResponse { Token = token }); | ||||
|             return Ok(challenge); | ||||
|         } | ||||
|         catch (SecurityTokenValidationException ex) | ||||
|         { | ||||
|   | ||||
| @@ -16,6 +16,7 @@ public abstract class OidcService( | ||||
|     IConfiguration configuration, | ||||
|     IHttpClientFactory httpClientFactory, | ||||
|     AppDatabase db, | ||||
|     AuthService auth, | ||||
|     ICacheService cache | ||||
| ) | ||||
| { | ||||
| @@ -187,7 +188,7 @@ public abstract class OidcService( | ||||
|     /// Creates a challenge and session for an authenticated user | ||||
|     /// Also creates or updates the account connection | ||||
|     /// </summary> | ||||
|     public async Task<Session> CreateSessionForUserAsync( | ||||
|     public async Task<Challenge> CreateChallengeForUserAsync( | ||||
|         OidcUserInfo userInfo, | ||||
|         Account.Account account, | ||||
|         HttpContext request, | ||||
| @@ -220,8 +221,7 @@ public abstract class OidcService( | ||||
|         var challenge = new Challenge | ||||
|         { | ||||
|             ExpiredAt = now.Plus(Duration.FromHours(1)), | ||||
|             StepTotal = 1, | ||||
|             StepRemain = 0, // Already verified by provider | ||||
|             StepTotal = await auth.DetectChallengeRisk(request.Request, account), | ||||
|             Type = ChallengeType.Oidc, | ||||
|             Platform = ChallengePlatform.Unidentified, | ||||
|             Audiences = [ProviderName], | ||||
| @@ -231,22 +231,13 @@ public abstract class OidcService( | ||||
|             IpAddress = request.Connection.RemoteIpAddress?.ToString() ?? null, | ||||
|             UserAgent = request.Request.Headers.UserAgent, | ||||
|         }; | ||||
|         challenge.StepRemain--; | ||||
|         if (challenge.StepRemain < 0) challenge.StepRemain = 0; | ||||
|  | ||||
|         await Db.AuthChallenges.AddAsync(challenge); | ||||
|  | ||||
|         // Create a session | ||||
|         var session = new Session | ||||
|         { | ||||
|             AccountId = account.Id, | ||||
|             CreatedAt = now, | ||||
|             LastGrantedAt = now, | ||||
|             Challenge = challenge | ||||
|         }; | ||||
|  | ||||
|         await Db.AuthSessions.AddAsync(session); | ||||
|         await Db.SaveChangesAsync(); | ||||
|  | ||||
|         return session; | ||||
|         return challenge; | ||||
|     } | ||||
| } | ||||
|  | ||||
|   | ||||
							
								
								
									
										189
									
								
								DysonNetwork.Sphere/Auth/OpenId/OidcState.cs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										189
									
								
								DysonNetwork.Sphere/Auth/OpenId/OidcState.cs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,189 @@ | ||||
| using System.Text.Json; | ||||
| using System.Text.Json.Serialization; | ||||
|  | ||||
| namespace DysonNetwork.Sphere.Auth.OpenId; | ||||
|  | ||||
| /// <summary> | ||||
| /// Represents the state parameter used in OpenID Connect flows. | ||||
| /// Handles serialization and deserialization of the state parameter. | ||||
| /// </summary> | ||||
| public class OidcState | ||||
| { | ||||
|     /// <summary> | ||||
|     /// The type of OIDC flow (login or connect). | ||||
|     /// </summary> | ||||
|     public OidcFlowType FlowType { get; set; } | ||||
|  | ||||
|     /// <summary> | ||||
|     /// The account ID (for connect flow). | ||||
|     /// </summary> | ||||
|     public Guid? AccountId { get; set; } | ||||
|  | ||||
|  | ||||
|     /// <summary> | ||||
|     /// The OIDC provider name. | ||||
|     /// </summary> | ||||
|     public string? Provider { get; set; } | ||||
|  | ||||
|  | ||||
|     /// <summary> | ||||
|     /// The nonce for CSRF protection. | ||||
|     /// </summary> | ||||
|     public string? Nonce { get; set; } | ||||
|  | ||||
|  | ||||
|     /// <summary> | ||||
|     /// The device ID for the authentication request. | ||||
|     /// </summary> | ||||
|     public string? DeviceId { get; set; } | ||||
|  | ||||
|  | ||||
|     /// <summary> | ||||
|     /// The return URL after authentication (for login flow). | ||||
|     /// </summary> | ||||
|     public string? ReturnUrl { get; set; } | ||||
|  | ||||
|  | ||||
|     /// <summary> | ||||
|     /// Creates a new OidcState for a connection flow. | ||||
|     /// </summary> | ||||
|     public static OidcState ForConnection(Guid accountId, string provider, string nonce, string? deviceId = null) | ||||
|     { | ||||
|         return new OidcState | ||||
|         { | ||||
|             FlowType = OidcFlowType.Connect, | ||||
|             AccountId = accountId, | ||||
|             Provider = provider, | ||||
|             Nonce = nonce, | ||||
|             DeviceId = deviceId | ||||
|         }; | ||||
|     } | ||||
|  | ||||
|     /// <summary> | ||||
|     /// Creates a new OidcState for a login flow. | ||||
|     /// </summary> | ||||
|     public static OidcState ForLogin(string returnUrl = "/", string? deviceId = null) | ||||
|     { | ||||
|         return new OidcState | ||||
|         { | ||||
|             FlowType = OidcFlowType.Login, | ||||
|             ReturnUrl = returnUrl, | ||||
|             DeviceId = deviceId | ||||
|         }; | ||||
|     } | ||||
|  | ||||
|     /// <summary> | ||||
|     /// The version of the state format. | ||||
|     /// </summary> | ||||
|     public int Version { get; set; } = 1; | ||||
|  | ||||
|     /// <summary> | ||||
|     /// Serializes the state to a JSON string for use in OIDC flows. | ||||
|     /// </summary> | ||||
|     public string Serialize() | ||||
|     { | ||||
|         return JsonSerializer.Serialize(this, new JsonSerializerOptions | ||||
|         { | ||||
|             PropertyNamingPolicy = JsonNamingPolicy.CamelCase, | ||||
|             DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull | ||||
|         }); | ||||
|     } | ||||
|  | ||||
|     /// <summary> | ||||
|     /// Attempts to parse a state string into an OidcState object. | ||||
|     /// </summary> | ||||
|     public static bool TryParse(string? stateString, out OidcState? state) | ||||
|     { | ||||
|         state = null; | ||||
|  | ||||
|         if (string.IsNullOrEmpty(stateString)) | ||||
|             return false; | ||||
|  | ||||
|         try | ||||
|         { | ||||
|             // First try to parse as JSON | ||||
|             try | ||||
|             { | ||||
|                 state = JsonSerializer.Deserialize<OidcState>(stateString); | ||||
|                 return state != null; | ||||
|             } | ||||
|             catch (JsonException) | ||||
|             { | ||||
|                 // Not a JSON string, try legacy format for backward compatibility | ||||
|                 return TryParseLegacyFormat(stateString, out state); | ||||
|             } | ||||
|         } | ||||
|         catch | ||||
|         { | ||||
|             return false; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     private static bool TryParseLegacyFormat(string stateString, out OidcState? state) | ||||
|     { | ||||
|         state = null; | ||||
|         var parts = stateString.Split('|'); | ||||
|  | ||||
|         // Check for connection flow format: {accountId}|{provider}|{nonce}|{deviceId}|connect | ||||
|         if (parts.Length >= 5 && | ||||
|             Guid.TryParse(parts[0], out var accountId) && | ||||
|             string.Equals(parts[^1], "connect", StringComparison.OrdinalIgnoreCase)) | ||||
|         { | ||||
|             state = new OidcState | ||||
|             { | ||||
|                 FlowType = OidcFlowType.Connect, | ||||
|                 AccountId = accountId, | ||||
|                 Provider = parts[1], | ||||
|                 Nonce = parts[2], | ||||
|                 DeviceId = parts.Length >= 4 && !string.IsNullOrEmpty(parts[3]) ? parts[3] : null | ||||
|             }; | ||||
|             return true; | ||||
|         } | ||||
|  | ||||
|         // Check for login flow format: {returnUrl}|{deviceId}|login | ||||
|         if (parts.Length >= 2 && | ||||
|             parts.Length <= 3 && | ||||
|             (parts.Length < 3 || string.Equals(parts[^1], "login", StringComparison.OrdinalIgnoreCase))) | ||||
|         { | ||||
|             state = new OidcState | ||||
|             { | ||||
|                 FlowType = OidcFlowType.Login, | ||||
|                 ReturnUrl = parts[0], | ||||
|                 DeviceId = parts.Length >= 2 && !string.IsNullOrEmpty(parts[1]) ? parts[1] : null | ||||
|             }; | ||||
|             return true; | ||||
|         } | ||||
|  | ||||
|         // Legacy format support (for backward compatibility) | ||||
|         if (parts.Length == 1) | ||||
|         { | ||||
|             state = new OidcState | ||||
|             { | ||||
|                 FlowType = OidcFlowType.Login, | ||||
|                 ReturnUrl = parts[0], | ||||
|                 DeviceId = null | ||||
|             }; | ||||
|             return true; | ||||
|         } | ||||
|  | ||||
|  | ||||
|         return false; | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// <summary> | ||||
| /// Represents the type of OIDC flow. | ||||
| /// </summary> | ||||
| public enum OidcFlowType | ||||
| { | ||||
|     /// <summary> | ||||
|     /// Login or registration flow. | ||||
|     /// </summary> | ||||
|     Login, | ||||
|  | ||||
|  | ||||
|     /// <summary> | ||||
|     /// Account connection flow. | ||||
|     /// </summary> | ||||
|     Connect | ||||
| } | ||||
		Reference in New Issue
	
	Block a user