♻️ Refactor the state and auth system

This commit is contained in:
2025-06-18 01:23:06 +08:00
parent aba0f6b5e2
commit 2a5926a94a
9 changed files with 274 additions and 49 deletions

View File

@ -176,23 +176,52 @@ public class ConnectionController(
if (callbackData.State == null)
return BadRequest("State parameter is missing.");
// Get and validate state from cache
// Get the state from the cache
var stateKey = $"{StateCachePrefix}{callbackData.State}";
var stateValue = await cache.GetAsync<string>(stateKey);
if (string.IsNullOrEmpty(stateValue))
return BadRequest("Invalid or expired state parameter");
// Remove state from cache to prevent replay attacks
// Try to get the state as OidcState first (new format)
var oidcState = await cache.GetAsync<OidcState>(stateKey);
// If not found, try to get as string (legacy format)
if (oidcState == null)
{
var stateValue = await cache.GetAsync<string>(stateKey);
if (string.IsNullOrEmpty(stateValue) || !OidcState.TryParse(stateValue, out oidcState) || oidcState == null)
return BadRequest("Invalid or expired state parameter");
}
// Remove the state from cache to prevent replay attacks
await cache.RemoveAsync(stateKey);
var stateParts = stateValue.Split('|');
if (stateParts.Length != 3)
// Handle the flow based on state type
if (oidcState.FlowType == OidcFlowType.Connect && oidcState.AccountId.HasValue)
{
return BadRequest("Invalid state format");
// Connection flow
if (oidcState.DeviceId != null)
{
callbackData.State = oidcState.DeviceId;
}
return await HandleManualConnection(provider, oidcService, callbackData, oidcState.AccountId.Value);
}
else if (oidcState.FlowType == OidcFlowType.Login)
{
// Login/Registration flow
if (!string.IsNullOrEmpty(oidcState.DeviceId))
{
callbackData.State = oidcState.DeviceId;
}
// Store return URL if provided
if (!string.IsNullOrEmpty(oidcState.ReturnUrl) && oidcState.ReturnUrl != "/")
{
var returnUrlKey = $"{ReturnUrlCachePrefix}{callbackData.State}";
await cache.SetAsync(returnUrlKey, oidcState.ReturnUrl, StateExpiration);
}
return await HandleLoginOrRegistration(provider, oidcService, callbackData);
}
var accountId = Guid.Parse(stateParts[0]);
return await HandleManualConnection(provider, oidcService, callbackData, accountId);
return BadRequest("Unsupported flow type");
}
private async Task<IActionResult> HandleManualConnection(
@ -219,6 +248,9 @@ public class ConnectionController(
return BadRequest($"{provider} did not return a valid user identifier.");
}
// Extract device ID from the callback state if available
var deviceId = !string.IsNullOrEmpty(callbackData.State) ? callbackData.State : string.Empty;
// Check if this provider account is already connected to any user
var existingConnection = await db.AccountConnections
.FirstOrDefaultAsync(c =>
@ -314,9 +346,16 @@ public class ConnectionController(
if (connection != null)
{
// Login existing user
var session = await auth.CreateSessionAsync(connection.Account, clock.GetCurrentInstant());
var token = auth.CreateToken(session);
return Redirect($"/auth/token?token={token}");
var deviceId = !string.IsNullOrEmpty(callbackData.State) ?
callbackData.State.Split('|').FirstOrDefault() :
string.Empty;
var challenge = await oidcService.CreateChallengeForUserAsync(
userInfo,
connection.Account,
HttpContext,
deviceId ?? string.Empty);
return Redirect($"/auth/callback?context={challenge.Id}");
}
// Register new user