🐛 Fix PCKE state broke the callback
This commit is contained in:
		| @@ -15,12 +15,13 @@ public class ConnectionController( | ||||
|     IEnumerable<OidcService> oidcServices, | ||||
|     AccountService accounts, | ||||
|     AuthService auth, | ||||
|     ICacheService cacheService | ||||
|     ICacheService cache | ||||
| ) : ControllerBase | ||||
| { | ||||
|     private const string StateCachePrefix = "oidc-state:"; | ||||
|     private const string ReturnUrlCachePrefix = "oidc-returning:"; | ||||
|     private static readonly TimeSpan StateExpiration = TimeSpan.FromMinutes(15); | ||||
|  | ||||
|     [HttpGet] | ||||
|     public async Task<ActionResult<List<AccountConnection>>> GetConnections() | ||||
|     { | ||||
| @@ -34,7 +35,7 @@ public class ConnectionController( | ||||
|                 c.Id, | ||||
|                 c.AccountId, | ||||
|                 c.Provider, | ||||
|                 c.ProvidedIdentifier,  | ||||
|                 c.ProvidedIdentifier, | ||||
|                 c.Meta, | ||||
|                 c.LastUsedAt, | ||||
|                 c.CreatedAt, | ||||
| @@ -150,8 +151,8 @@ public class ConnectionController( | ||||
|         var finalReturnUrl = !string.IsNullOrEmpty(request.ReturnUrl) ? request.ReturnUrl : "/settings/connections"; | ||||
|  | ||||
|         // Store state and return URL in cache | ||||
|         await cacheService.SetAsync($"{StateCachePrefix}{state}", stateValue, StateExpiration); | ||||
|         await cacheService.SetAsync($"{ReturnUrlCachePrefix}{state}", finalReturnUrl, StateExpiration); | ||||
|         await cache.SetAsync($"{StateCachePrefix}{state}", stateValue, StateExpiration); | ||||
|         await cache.SetAsync($"{ReturnUrlCachePrefix}{state}", finalReturnUrl, StateExpiration); | ||||
|  | ||||
|         var authUrl = oidcService.GetAuthorizationUrl(state, nonce); | ||||
|  | ||||
| @@ -177,14 +178,12 @@ public class ConnectionController( | ||||
|  | ||||
|         // Get and validate state from cache | ||||
|         var stateKey = $"{StateCachePrefix}{callbackData.State}"; | ||||
|         var stateValue = await cacheService.GetAsync<string>(stateKey); | ||||
|         var stateValue = await cache.GetAsync<string>(stateKey); | ||||
|         if (string.IsNullOrEmpty(stateValue)) | ||||
|         { | ||||
|             return BadRequest("Invalid or expired state parameter"); | ||||
|         } | ||||
|  | ||||
|         // Remove state from cache to prevent replay attacks | ||||
|         await cacheService.RemoveAsync(stateKey); | ||||
|         await cache.RemoveAsync(stateKey); | ||||
|  | ||||
|         var stateParts = stateValue.Split('|'); | ||||
|         if (stateParts.Length != 3) | ||||
| @@ -228,8 +227,8 @@ public class ConnectionController( | ||||
|  | ||||
|         // Check if the current user already has this provider connected | ||||
|         var userHasProvider = await db.AccountConnections | ||||
|             .AnyAsync(c =>  | ||||
|                 c.AccountId == accountId &&  | ||||
|             .AnyAsync(c => | ||||
|                 c.AccountId == accountId && | ||||
|                 c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase)); | ||||
|  | ||||
|         if (userHasProvider) | ||||
| @@ -237,7 +236,7 @@ public class ConnectionController( | ||||
|             // Update existing connection with new tokens | ||||
|             var connection = await db.AccountConnections | ||||
|                 .FirstOrDefaultAsync(c => | ||||
|                     c.AccountId == accountId &&  | ||||
|                     c.AccountId == accountId && | ||||
|                     c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase)); | ||||
|  | ||||
|             if (connection != null) | ||||
| @@ -273,9 +272,9 @@ public class ConnectionController( | ||||
|         } | ||||
|  | ||||
|         // Clean up and redirect | ||||
|         var returnUrlKey = $"{ReturnUrlCachePrefix}{callbackData.State}"; | ||||
|         var returnUrl = await cacheService.GetAsync<string>(returnUrlKey); | ||||
|         await cacheService.RemoveAsync(returnUrlKey); | ||||
|         var returnUrlKey = $"{ReturnUrlCachePrefix}{CleanStateCodeVerifier(callbackData.State)}"; | ||||
|         var returnUrl = await cache.GetAsync<string>(returnUrlKey); | ||||
|         await cache.RemoveAsync(returnUrlKey); | ||||
|  | ||||
|         return Redirect(string.IsNullOrEmpty(returnUrl) ? "/settings/connections" : returnUrl); | ||||
|     } | ||||
| @@ -354,9 +353,7 @@ public class ConnectionController( | ||||
|                 data.IdToken = form["id_token"].FirstOrDefault() ?? ""; | ||||
|                 data.State = form["state"].FirstOrDefault(); | ||||
|                 if (form.ContainsKey("user")) | ||||
|                 { | ||||
|                     data.RawData = form["user"].FirstOrDefault(); | ||||
|                 } | ||||
|  | ||||
|                 break; | ||||
|             } | ||||
| @@ -364,4 +361,9 @@ public class ConnectionController( | ||||
|  | ||||
|         return data; | ||||
|     } | ||||
|  | ||||
|     private static string? CleanStateCodeVerifier(string? og) | ||||
|     { | ||||
|         return og is null ? null : !og.Contains('|') ? og : og.Split('|').First(); | ||||
|     } | ||||
| } | ||||
| @@ -175,8 +175,7 @@ public class GoogleOidcService( | ||||
|  | ||||
|     public string GenerateCodeChallenge(string codeVerifier) | ||||
|     { | ||||
|         using var sha256 = System.Security.Cryptography.SHA256.Create(); | ||||
|         var challengeBytes = sha256.ComputeHash(Encoding.UTF8.GetBytes(codeVerifier)); | ||||
|         var challengeBytes = SHA256.HashData(Encoding.UTF8.GetBytes(codeVerifier)); | ||||
|         return Convert.ToBase64String(challengeBytes) | ||||
|             .Replace('+', '-') | ||||
|             .Replace('/', '_') | ||||
|   | ||||
| @@ -22,7 +22,7 @@ public class OidcController( | ||||
|     private static readonly TimeSpan StateExpiration = TimeSpan.FromMinutes(15); | ||||
|  | ||||
|     [HttpGet("{provider}")] | ||||
|     public async Task<ActionResult> SignIn([FromRoute] string provider, [FromQuery] string? returnUrl = "/") | ||||
|     public async Task<ActionResult> OidcLogin([FromRoute] string provider, [FromQuery] string? returnUrl = "/") | ||||
|     { | ||||
|         try | ||||
|         { | ||||
| @@ -42,13 +42,12 @@ public class OidcController( | ||||
|                 var authUrl = oidcService.GetAuthorizationUrl(state, nonce); | ||||
|                 return Redirect(authUrl); | ||||
|             } | ||||
|             else // Otherwise, proceed with login/registration flow | ||||
|             else // Otherwise, proceed with the login / registration flow | ||||
|             { | ||||
|                 var state = returnUrl; | ||||
|                 var nonce = Guid.NewGuid().ToString(); | ||||
|  | ||||
|                 // The state parameter is the returnUrl. The callback will not find a session state and will treat it as a login. | ||||
|                 var authUrl = oidcService.GetAuthorizationUrl(state ?? "/", nonce); | ||||
|                 var authUrl = oidcService.GetAuthorizationUrl(returnUrl ?? "/", nonce); | ||||
|                 return Redirect(authUrl); | ||||
|             } | ||||
|         } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user