🐛 Fix PCKE state broke the callback

This commit is contained in:
LittleSheep 2025-06-17 21:30:58 +08:00
parent 5e455599fd
commit 3aad515ab8
3 changed files with 22 additions and 22 deletions

View File

@ -15,12 +15,13 @@ public class ConnectionController(
IEnumerable<OidcService> oidcServices,
AccountService accounts,
AuthService auth,
ICacheService cacheService
ICacheService cache
) : ControllerBase
{
private const string StateCachePrefix = "oidc-state:";
private const string ReturnUrlCachePrefix = "oidc-returning:";
private static readonly TimeSpan StateExpiration = TimeSpan.FromMinutes(15);
[HttpGet]
public async Task<ActionResult<List<AccountConnection>>> GetConnections()
{
@ -34,7 +35,7 @@ public class ConnectionController(
c.Id,
c.AccountId,
c.Provider,
c.ProvidedIdentifier,
c.ProvidedIdentifier,
c.Meta,
c.LastUsedAt,
c.CreatedAt,
@ -150,8 +151,8 @@ public class ConnectionController(
var finalReturnUrl = !string.IsNullOrEmpty(request.ReturnUrl) ? request.ReturnUrl : "/settings/connections";
// Store state and return URL in cache
await cacheService.SetAsync($"{StateCachePrefix}{state}", stateValue, StateExpiration);
await cacheService.SetAsync($"{ReturnUrlCachePrefix}{state}", finalReturnUrl, StateExpiration);
await cache.SetAsync($"{StateCachePrefix}{state}", stateValue, StateExpiration);
await cache.SetAsync($"{ReturnUrlCachePrefix}{state}", finalReturnUrl, StateExpiration);
var authUrl = oidcService.GetAuthorizationUrl(state, nonce);
@ -177,14 +178,12 @@ public class ConnectionController(
// Get and validate state from cache
var stateKey = $"{StateCachePrefix}{callbackData.State}";
var stateValue = await cacheService.GetAsync<string>(stateKey);
var stateValue = await cache.GetAsync<string>(stateKey);
if (string.IsNullOrEmpty(stateValue))
{
return BadRequest("Invalid or expired state parameter");
}
// Remove state from cache to prevent replay attacks
await cacheService.RemoveAsync(stateKey);
await cache.RemoveAsync(stateKey);
var stateParts = stateValue.Split('|');
if (stateParts.Length != 3)
@ -228,8 +227,8 @@ public class ConnectionController(
// Check if the current user already has this provider connected
var userHasProvider = await db.AccountConnections
.AnyAsync(c =>
c.AccountId == accountId &&
.AnyAsync(c =>
c.AccountId == accountId &&
c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase));
if (userHasProvider)
@ -237,7 +236,7 @@ public class ConnectionController(
// Update existing connection with new tokens
var connection = await db.AccountConnections
.FirstOrDefaultAsync(c =>
c.AccountId == accountId &&
c.AccountId == accountId &&
c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase));
if (connection != null)
@ -273,9 +272,9 @@ public class ConnectionController(
}
// Clean up and redirect
var returnUrlKey = $"{ReturnUrlCachePrefix}{callbackData.State}";
var returnUrl = await cacheService.GetAsync<string>(returnUrlKey);
await cacheService.RemoveAsync(returnUrlKey);
var returnUrlKey = $"{ReturnUrlCachePrefix}{CleanStateCodeVerifier(callbackData.State)}";
var returnUrl = await cache.GetAsync<string>(returnUrlKey);
await cache.RemoveAsync(returnUrlKey);
return Redirect(string.IsNullOrEmpty(returnUrl) ? "/settings/connections" : returnUrl);
}
@ -354,9 +353,7 @@ public class ConnectionController(
data.IdToken = form["id_token"].FirstOrDefault() ?? "";
data.State = form["state"].FirstOrDefault();
if (form.ContainsKey("user"))
{
data.RawData = form["user"].FirstOrDefault();
}
break;
}
@ -364,4 +361,9 @@ public class ConnectionController(
return data;
}
private static string? CleanStateCodeVerifier(string? og)
{
return og is null ? null : !og.Contains('|') ? og : og.Split('|').First();
}
}

View File

@ -175,8 +175,7 @@ public class GoogleOidcService(
public string GenerateCodeChallenge(string codeVerifier)
{
using var sha256 = System.Security.Cryptography.SHA256.Create();
var challengeBytes = sha256.ComputeHash(Encoding.UTF8.GetBytes(codeVerifier));
var challengeBytes = SHA256.HashData(Encoding.UTF8.GetBytes(codeVerifier));
return Convert.ToBase64String(challengeBytes)
.Replace('+', '-')
.Replace('/', '_')

View File

@ -22,7 +22,7 @@ public class OidcController(
private static readonly TimeSpan StateExpiration = TimeSpan.FromMinutes(15);
[HttpGet("{provider}")]
public async Task<ActionResult> SignIn([FromRoute] string provider, [FromQuery] string? returnUrl = "/")
public async Task<ActionResult> OidcLogin([FromRoute] string provider, [FromQuery] string? returnUrl = "/")
{
try
{
@ -42,13 +42,12 @@ public class OidcController(
var authUrl = oidcService.GetAuthorizationUrl(state, nonce);
return Redirect(authUrl);
}
else // Otherwise, proceed with login/registration flow
else // Otherwise, proceed with the login / registration flow
{
var state = returnUrl;
var nonce = Guid.NewGuid().ToString();
// The state parameter is the returnUrl. The callback will not find a session state and will treat it as a login.
var authUrl = oidcService.GetAuthorizationUrl(state ?? "/", nonce);
var authUrl = oidcService.GetAuthorizationUrl(returnUrl ?? "/", nonce);
return Redirect(authUrl);
}
}