♻️ Refactor the state and auth system

This commit is contained in:
LittleSheep 2025-06-18 01:23:06 +08:00
parent aba0f6b5e2
commit 2a5926a94a
9 changed files with 274 additions and 49 deletions

View File

@ -15,9 +15,10 @@ public class AppleOidcService(
IConfiguration configuration, IConfiguration configuration,
IHttpClientFactory httpClientFactory, IHttpClientFactory httpClientFactory,
AppDatabase db, AppDatabase db,
AuthService auth,
ICacheService cache ICacheService cache
) )
: OidcService(configuration, httpClientFactory, db, cache) : OidcService(configuration, httpClientFactory, db, auth, cache)
{ {
private readonly IConfiguration _configuration = configuration; private readonly IConfiguration _configuration = configuration;
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory; private readonly IHttpClientFactory _httpClientFactory = httpClientFactory;

View File

@ -176,23 +176,52 @@ public class ConnectionController(
if (callbackData.State == null) if (callbackData.State == null)
return BadRequest("State parameter is missing."); return BadRequest("State parameter is missing.");
// Get and validate state from cache // Get the state from the cache
var stateKey = $"{StateCachePrefix}{callbackData.State}"; var stateKey = $"{StateCachePrefix}{callbackData.State}";
var stateValue = await cache.GetAsync<string>(stateKey);
if (string.IsNullOrEmpty(stateValue)) // Try to get the state as OidcState first (new format)
return BadRequest("Invalid or expired state parameter"); var oidcState = await cache.GetAsync<OidcState>(stateKey);
// Remove state from cache to prevent replay attacks // If not found, try to get as string (legacy format)
if (oidcState == null)
{
var stateValue = await cache.GetAsync<string>(stateKey);
if (string.IsNullOrEmpty(stateValue) || !OidcState.TryParse(stateValue, out oidcState) || oidcState == null)
return BadRequest("Invalid or expired state parameter");
}
// Remove the state from cache to prevent replay attacks
await cache.RemoveAsync(stateKey); await cache.RemoveAsync(stateKey);
var stateParts = stateValue.Split('|'); // Handle the flow based on state type
if (stateParts.Length != 3) if (oidcState.FlowType == OidcFlowType.Connect && oidcState.AccountId.HasValue)
{ {
return BadRequest("Invalid state format"); // Connection flow
if (oidcState.DeviceId != null)
{
callbackData.State = oidcState.DeviceId;
}
return await HandleManualConnection(provider, oidcService, callbackData, oidcState.AccountId.Value);
}
else if (oidcState.FlowType == OidcFlowType.Login)
{
// Login/Registration flow
if (!string.IsNullOrEmpty(oidcState.DeviceId))
{
callbackData.State = oidcState.DeviceId;
}
// Store return URL if provided
if (!string.IsNullOrEmpty(oidcState.ReturnUrl) && oidcState.ReturnUrl != "/")
{
var returnUrlKey = $"{ReturnUrlCachePrefix}{callbackData.State}";
await cache.SetAsync(returnUrlKey, oidcState.ReturnUrl, StateExpiration);
}
return await HandleLoginOrRegistration(provider, oidcService, callbackData);
} }
var accountId = Guid.Parse(stateParts[0]); return BadRequest("Unsupported flow type");
return await HandleManualConnection(provider, oidcService, callbackData, accountId);
} }
private async Task<IActionResult> HandleManualConnection( private async Task<IActionResult> HandleManualConnection(
@ -219,6 +248,9 @@ public class ConnectionController(
return BadRequest($"{provider} did not return a valid user identifier."); return BadRequest($"{provider} did not return a valid user identifier.");
} }
// Extract device ID from the callback state if available
var deviceId = !string.IsNullOrEmpty(callbackData.State) ? callbackData.State : string.Empty;
// Check if this provider account is already connected to any user // Check if this provider account is already connected to any user
var existingConnection = await db.AccountConnections var existingConnection = await db.AccountConnections
.FirstOrDefaultAsync(c => .FirstOrDefaultAsync(c =>
@ -314,9 +346,16 @@ public class ConnectionController(
if (connection != null) if (connection != null)
{ {
// Login existing user // Login existing user
var session = await auth.CreateSessionAsync(connection.Account, clock.GetCurrentInstant()); var deviceId = !string.IsNullOrEmpty(callbackData.State) ?
var token = auth.CreateToken(session); callbackData.State.Split('|').FirstOrDefault() :
return Redirect($"/auth/token?token={token}"); string.Empty;
var challenge = await oidcService.CreateChallengeForUserAsync(
userInfo,
connection.Account,
HttpContext,
deviceId ?? string.Empty);
return Redirect($"/auth/callback?context={challenge.Id}");
} }
// Register new user // Register new user

View File

@ -8,9 +8,10 @@ public class DiscordOidcService(
IConfiguration configuration, IConfiguration configuration,
IHttpClientFactory httpClientFactory, IHttpClientFactory httpClientFactory,
AppDatabase db, AppDatabase db,
AuthService auth,
ICacheService cache ICacheService cache
) )
: OidcService(configuration, httpClientFactory, db, cache) : OidcService(configuration, httpClientFactory, db, auth, cache)
{ {
public override string ProviderName => "Discord"; public override string ProviderName => "Discord";
protected override string DiscoveryEndpoint => ""; // Discord doesn't have a standard OIDC discovery endpoint protected override string DiscoveryEndpoint => ""; // Discord doesn't have a standard OIDC discovery endpoint

View File

@ -8,9 +8,10 @@ public class GitHubOidcService(
IConfiguration configuration, IConfiguration configuration,
IHttpClientFactory httpClientFactory, IHttpClientFactory httpClientFactory,
AppDatabase db, AppDatabase db,
AuthService auth,
ICacheService cache ICacheService cache
) )
: OidcService(configuration, httpClientFactory, db, cache) : OidcService(configuration, httpClientFactory, db, auth, cache)
{ {
public override string ProviderName => "GitHub"; public override string ProviderName => "GitHub";
protected override string DiscoveryEndpoint => ""; // GitHub doesn't have a standard OIDC discovery endpoint protected override string DiscoveryEndpoint => ""; // GitHub doesn't have a standard OIDC discovery endpoint

View File

@ -11,9 +11,10 @@ public class GoogleOidcService(
IConfiguration configuration, IConfiguration configuration,
IHttpClientFactory httpClientFactory, IHttpClientFactory httpClientFactory,
AppDatabase db, AppDatabase db,
AuthService auth,
ICacheService cache ICacheService cache
) )
: OidcService(configuration, httpClientFactory, db, cache) : OidcService(configuration, httpClientFactory, db, auth, cache)
{ {
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory; private readonly IHttpClientFactory _httpClientFactory = httpClientFactory;
@ -31,8 +32,6 @@ public class GoogleOidcService(
throw new InvalidOperationException("Authorization endpoint not found in discovery document"); throw new InvalidOperationException("Authorization endpoint not found in discovery document");
} }
// PKCE code removed: no code verifier/challenge generated
var queryParams = new Dictionary<string, string> var queryParams = new Dictionary<string, string>
{ {
{ "client_id", config.ClientId }, { "client_id", config.ClientId },

View File

@ -8,9 +8,10 @@ public class MicrosoftOidcService(
IConfiguration configuration, IConfiguration configuration,
IHttpClientFactory httpClientFactory, IHttpClientFactory httpClientFactory,
AppDatabase db, AppDatabase db,
AuthService auth,
ICacheService cache ICacheService cache
) )
: OidcService(configuration, httpClientFactory, db, cache) : OidcService(configuration, httpClientFactory, db, auth, cache)
{ {
public override string ProviderName => "Microsoft"; public override string ProviderName => "Microsoft";

View File

@ -13,7 +13,6 @@ public class OidcController(
IServiceProvider serviceProvider, IServiceProvider serviceProvider,
AppDatabase db, AppDatabase db,
AccountService accounts, AccountService accounts,
AuthService auth,
ICacheService cache ICacheService cache
) )
: ControllerBase : ControllerBase
@ -22,7 +21,11 @@ public class OidcController(
private static readonly TimeSpan StateExpiration = TimeSpan.FromMinutes(15); private static readonly TimeSpan StateExpiration = TimeSpan.FromMinutes(15);
[HttpGet("{provider}")] [HttpGet("{provider}")]
public async Task<ActionResult> OidcLogin([FromRoute] string provider, [FromQuery] string? returnUrl = "/") public async Task<ActionResult> OidcLogin(
[FromRoute] string provider,
[FromQuery] string? returnUrl = "/",
[FromHeader(Name = "X-Device-Id")] string? deviceId = null
)
{ {
try try
{ {
@ -34,9 +37,9 @@ public class OidcController(
var state = Guid.NewGuid().ToString(); var state = Guid.NewGuid().ToString();
var nonce = Guid.NewGuid().ToString(); var nonce = Guid.NewGuid().ToString();
// Store user's ID, provider, and nonce in cache. The callback will use this. // Create and store connection state
var stateValue = $"{currentUser.Id}|{provider}|{nonce}"; var oidcState = OidcState.ForConnection(currentUser.Id, provider, nonce, deviceId);
await cache.SetAsync($"{StateCachePrefix}{state}", stateValue, StateExpiration); await cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration);
// The state parameter sent to the provider is the GUID key for the cache. // The state parameter sent to the provider is the GUID key for the cache.
var authUrl = oidcService.GetAuthorizationUrl(state, nonce); var authUrl = oidcService.GetAuthorizationUrl(state, nonce);
@ -45,9 +48,12 @@ public class OidcController(
else // Otherwise, proceed with the login / registration flow else // Otherwise, proceed with the login / registration flow
{ {
var nonce = Guid.NewGuid().ToString(); var nonce = Guid.NewGuid().ToString();
var state = Guid.NewGuid().ToString();
// The state parameter is the returnUrl. The callback will not find a session state and will treat it as a login. // Create login state with return URL and device ID
var authUrl = oidcService.GetAuthorizationUrl(returnUrl ?? "/", nonce); var oidcState = OidcState.ForLogin(returnUrl ?? "/", deviceId);
await cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration);
var authUrl = oidcService.GetAuthorizationUrl(state, nonce);
return Redirect(authUrl); return Redirect(authUrl);
} }
} }
@ -62,7 +68,7 @@ public class OidcController(
/// Handles Apple authentication directly from mobile apps /// Handles Apple authentication directly from mobile apps
/// </summary> /// </summary>
[HttpPost("apple/mobile")] [HttpPost("apple/mobile")]
public async Task<ActionResult<AuthController.TokenExchangeResponse>> AppleMobileSignIn( public async Task<ActionResult<Challenge>> AppleMobileLogin(
[FromBody] AppleMobileSignInRequest request) [FromBody] AppleMobileSignInRequest request)
{ {
try try
@ -85,17 +91,14 @@ public class OidcController(
var account = await FindOrCreateAccount(userInfo, "apple"); var account = await FindOrCreateAccount(userInfo, "apple");
// Create session using the OIDC service // Create session using the OIDC service
var session = await appleService.CreateSessionForUserAsync( var challenge = await appleService.CreateChallengeForUserAsync(
userInfo, userInfo,
account, account,
HttpContext, HttpContext,
request.DeviceId request.DeviceId
); );
// Generate token using existing auth service return Ok(challenge);
var token = auth.CreateToken(session);
return Ok(new AuthController.TokenExchangeResponse { Token = token });
} }
catch (SecurityTokenValidationException ex) catch (SecurityTokenValidationException ex)
{ {

View File

@ -16,6 +16,7 @@ public abstract class OidcService(
IConfiguration configuration, IConfiguration configuration,
IHttpClientFactory httpClientFactory, IHttpClientFactory httpClientFactory,
AppDatabase db, AppDatabase db,
AuthService auth,
ICacheService cache ICacheService cache
) )
{ {
@ -187,7 +188,7 @@ public abstract class OidcService(
/// Creates a challenge and session for an authenticated user /// Creates a challenge and session for an authenticated user
/// Also creates or updates the account connection /// Also creates or updates the account connection
/// </summary> /// </summary>
public async Task<Session> CreateSessionForUserAsync( public async Task<Challenge> CreateChallengeForUserAsync(
OidcUserInfo userInfo, OidcUserInfo userInfo,
Account.Account account, Account.Account account,
HttpContext request, HttpContext request,
@ -220,8 +221,7 @@ public abstract class OidcService(
var challenge = new Challenge var challenge = new Challenge
{ {
ExpiredAt = now.Plus(Duration.FromHours(1)), ExpiredAt = now.Plus(Duration.FromHours(1)),
StepTotal = 1, StepTotal = await auth.DetectChallengeRisk(request.Request, account),
StepRemain = 0, // Already verified by provider
Type = ChallengeType.Oidc, Type = ChallengeType.Oidc,
Platform = ChallengePlatform.Unidentified, Platform = ChallengePlatform.Unidentified,
Audiences = [ProviderName], Audiences = [ProviderName],
@ -231,22 +231,13 @@ public abstract class OidcService(
IpAddress = request.Connection.RemoteIpAddress?.ToString() ?? null, IpAddress = request.Connection.RemoteIpAddress?.ToString() ?? null,
UserAgent = request.Request.Headers.UserAgent, UserAgent = request.Request.Headers.UserAgent,
}; };
challenge.StepRemain--;
if (challenge.StepRemain < 0) challenge.StepRemain = 0;
await Db.AuthChallenges.AddAsync(challenge); await Db.AuthChallenges.AddAsync(challenge);
// Create a session
var session = new Session
{
AccountId = account.Id,
CreatedAt = now,
LastGrantedAt = now,
Challenge = challenge
};
await Db.AuthSessions.AddAsync(session);
await Db.SaveChangesAsync(); await Db.SaveChangesAsync();
return session; return challenge;
} }
} }

View File

@ -0,0 +1,189 @@
using System.Text.Json;
using System.Text.Json.Serialization;
namespace DysonNetwork.Sphere.Auth.OpenId;
/// <summary>
/// Represents the state parameter used in OpenID Connect flows.
/// Handles serialization and deserialization of the state parameter.
/// </summary>
public class OidcState
{
/// <summary>
/// The type of OIDC flow (login or connect).
/// </summary>
public OidcFlowType FlowType { get; set; }
/// <summary>
/// The account ID (for connect flow).
/// </summary>
public Guid? AccountId { get; set; }
/// <summary>
/// The OIDC provider name.
/// </summary>
public string? Provider { get; set; }
/// <summary>
/// The nonce for CSRF protection.
/// </summary>
public string? Nonce { get; set; }
/// <summary>
/// The device ID for the authentication request.
/// </summary>
public string? DeviceId { get; set; }
/// <summary>
/// The return URL after authentication (for login flow).
/// </summary>
public string? ReturnUrl { get; set; }
/// <summary>
/// Creates a new OidcState for a connection flow.
/// </summary>
public static OidcState ForConnection(Guid accountId, string provider, string nonce, string? deviceId = null)
{
return new OidcState
{
FlowType = OidcFlowType.Connect,
AccountId = accountId,
Provider = provider,
Nonce = nonce,
DeviceId = deviceId
};
}
/// <summary>
/// Creates a new OidcState for a login flow.
/// </summary>
public static OidcState ForLogin(string returnUrl = "/", string? deviceId = null)
{
return new OidcState
{
FlowType = OidcFlowType.Login,
ReturnUrl = returnUrl,
DeviceId = deviceId
};
}
/// <summary>
/// The version of the state format.
/// </summary>
public int Version { get; set; } = 1;
/// <summary>
/// Serializes the state to a JSON string for use in OIDC flows.
/// </summary>
public string Serialize()
{
return JsonSerializer.Serialize(this, new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
});
}
/// <summary>
/// Attempts to parse a state string into an OidcState object.
/// </summary>
public static bool TryParse(string? stateString, out OidcState? state)
{
state = null;
if (string.IsNullOrEmpty(stateString))
return false;
try
{
// First try to parse as JSON
try
{
state = JsonSerializer.Deserialize<OidcState>(stateString);
return state != null;
}
catch (JsonException)
{
// Not a JSON string, try legacy format for backward compatibility
return TryParseLegacyFormat(stateString, out state);
}
}
catch
{
return false;
}
}
private static bool TryParseLegacyFormat(string stateString, out OidcState? state)
{
state = null;
var parts = stateString.Split('|');
// Check for connection flow format: {accountId}|{provider}|{nonce}|{deviceId}|connect
if (parts.Length >= 5 &&
Guid.TryParse(parts[0], out var accountId) &&
string.Equals(parts[^1], "connect", StringComparison.OrdinalIgnoreCase))
{
state = new OidcState
{
FlowType = OidcFlowType.Connect,
AccountId = accountId,
Provider = parts[1],
Nonce = parts[2],
DeviceId = parts.Length >= 4 && !string.IsNullOrEmpty(parts[3]) ? parts[3] : null
};
return true;
}
// Check for login flow format: {returnUrl}|{deviceId}|login
if (parts.Length >= 2 &&
parts.Length <= 3 &&
(parts.Length < 3 || string.Equals(parts[^1], "login", StringComparison.OrdinalIgnoreCase)))
{
state = new OidcState
{
FlowType = OidcFlowType.Login,
ReturnUrl = parts[0],
DeviceId = parts.Length >= 2 && !string.IsNullOrEmpty(parts[1]) ? parts[1] : null
};
return true;
}
// Legacy format support (for backward compatibility)
if (parts.Length == 1)
{
state = new OidcState
{
FlowType = OidcFlowType.Login,
ReturnUrl = parts[0],
DeviceId = null
};
return true;
}
return false;
}
}
/// <summary>
/// Represents the type of OIDC flow.
/// </summary>
public enum OidcFlowType
{
/// <summary>
/// Login or registration flow.
/// </summary>
Login,
/// <summary>
/// Account connection flow.
/// </summary>
Connect
}