diff --git a/DysonNetwork.Sphere/Auth/OpenId/AppleOidcService.cs b/DysonNetwork.Sphere/Auth/OpenId/AppleOidcService.cs index e8a782d..75420b8 100644 --- a/DysonNetwork.Sphere/Auth/OpenId/AppleOidcService.cs +++ b/DysonNetwork.Sphere/Auth/OpenId/AppleOidcService.cs @@ -15,9 +15,10 @@ public class AppleOidcService( IConfiguration configuration, IHttpClientFactory httpClientFactory, AppDatabase db, + AuthService auth, ICacheService cache ) - : OidcService(configuration, httpClientFactory, db, cache) + : OidcService(configuration, httpClientFactory, db, auth, cache) { private readonly IConfiguration _configuration = configuration; private readonly IHttpClientFactory _httpClientFactory = httpClientFactory; diff --git a/DysonNetwork.Sphere/Auth/OpenId/ConnectionController.cs b/DysonNetwork.Sphere/Auth/OpenId/ConnectionController.cs index e18ef6e..c3f6257 100644 --- a/DysonNetwork.Sphere/Auth/OpenId/ConnectionController.cs +++ b/DysonNetwork.Sphere/Auth/OpenId/ConnectionController.cs @@ -176,23 +176,52 @@ public class ConnectionController( if (callbackData.State == null) return BadRequest("State parameter is missing."); - // Get and validate state from cache + // Get the state from the cache var stateKey = $"{StateCachePrefix}{callbackData.State}"; - var stateValue = await cache.GetAsync(stateKey); - if (string.IsNullOrEmpty(stateValue)) - return BadRequest("Invalid or expired state parameter"); - - // Remove state from cache to prevent replay attacks + + // Try to get the state as OidcState first (new format) + var oidcState = await cache.GetAsync(stateKey); + + // If not found, try to get as string (legacy format) + if (oidcState == null) + { + var stateValue = await cache.GetAsync(stateKey); + if (string.IsNullOrEmpty(stateValue) || !OidcState.TryParse(stateValue, out oidcState) || oidcState == null) + return BadRequest("Invalid or expired state parameter"); + } + + // Remove the state from cache to prevent replay attacks await cache.RemoveAsync(stateKey); - var stateParts = stateValue.Split('|'); - if (stateParts.Length != 3) + // Handle the flow based on state type + if (oidcState.FlowType == OidcFlowType.Connect && oidcState.AccountId.HasValue) { - return BadRequest("Invalid state format"); + // Connection flow + if (oidcState.DeviceId != null) + { + callbackData.State = oidcState.DeviceId; + } + return await HandleManualConnection(provider, oidcService, callbackData, oidcState.AccountId.Value); + } + else if (oidcState.FlowType == OidcFlowType.Login) + { + // Login/Registration flow + if (!string.IsNullOrEmpty(oidcState.DeviceId)) + { + callbackData.State = oidcState.DeviceId; + } + + // Store return URL if provided + if (!string.IsNullOrEmpty(oidcState.ReturnUrl) && oidcState.ReturnUrl != "/") + { + var returnUrlKey = $"{ReturnUrlCachePrefix}{callbackData.State}"; + await cache.SetAsync(returnUrlKey, oidcState.ReturnUrl, StateExpiration); + } + + return await HandleLoginOrRegistration(provider, oidcService, callbackData); } - var accountId = Guid.Parse(stateParts[0]); - return await HandleManualConnection(provider, oidcService, callbackData, accountId); + return BadRequest("Unsupported flow type"); } private async Task HandleManualConnection( @@ -219,6 +248,9 @@ public class ConnectionController( return BadRequest($"{provider} did not return a valid user identifier."); } + // Extract device ID from the callback state if available + var deviceId = !string.IsNullOrEmpty(callbackData.State) ? callbackData.State : string.Empty; + // Check if this provider account is already connected to any user var existingConnection = await db.AccountConnections .FirstOrDefaultAsync(c => @@ -314,9 +346,16 @@ public class ConnectionController( if (connection != null) { // Login existing user - var session = await auth.CreateSessionAsync(connection.Account, clock.GetCurrentInstant()); - var token = auth.CreateToken(session); - return Redirect($"/auth/token?token={token}"); + var deviceId = !string.IsNullOrEmpty(callbackData.State) ? + callbackData.State.Split('|').FirstOrDefault() : + string.Empty; + + var challenge = await oidcService.CreateChallengeForUserAsync( + userInfo, + connection.Account, + HttpContext, + deviceId ?? string.Empty); + return Redirect($"/auth/callback?context={challenge.Id}"); } // Register new user diff --git a/DysonNetwork.Sphere/Auth/OpenId/DiscordOidcService.cs b/DysonNetwork.Sphere/Auth/OpenId/DiscordOidcService.cs index 6c7525f..ebf6d79 100644 --- a/DysonNetwork.Sphere/Auth/OpenId/DiscordOidcService.cs +++ b/DysonNetwork.Sphere/Auth/OpenId/DiscordOidcService.cs @@ -8,9 +8,10 @@ public class DiscordOidcService( IConfiguration configuration, IHttpClientFactory httpClientFactory, AppDatabase db, + AuthService auth, ICacheService cache ) - : OidcService(configuration, httpClientFactory, db, cache) + : OidcService(configuration, httpClientFactory, db, auth, cache) { public override string ProviderName => "Discord"; protected override string DiscoveryEndpoint => ""; // Discord doesn't have a standard OIDC discovery endpoint diff --git a/DysonNetwork.Sphere/Auth/OpenId/GitHubOidcService.cs b/DysonNetwork.Sphere/Auth/OpenId/GitHubOidcService.cs index 0bd753a..fc80bfe 100644 --- a/DysonNetwork.Sphere/Auth/OpenId/GitHubOidcService.cs +++ b/DysonNetwork.Sphere/Auth/OpenId/GitHubOidcService.cs @@ -8,9 +8,10 @@ public class GitHubOidcService( IConfiguration configuration, IHttpClientFactory httpClientFactory, AppDatabase db, + AuthService auth, ICacheService cache ) - : OidcService(configuration, httpClientFactory, db, cache) + : OidcService(configuration, httpClientFactory, db, auth, cache) { public override string ProviderName => "GitHub"; protected override string DiscoveryEndpoint => ""; // GitHub doesn't have a standard OIDC discovery endpoint diff --git a/DysonNetwork.Sphere/Auth/OpenId/GoogleOidcService.cs b/DysonNetwork.Sphere/Auth/OpenId/GoogleOidcService.cs index a909547..a446b2e 100644 --- a/DysonNetwork.Sphere/Auth/OpenId/GoogleOidcService.cs +++ b/DysonNetwork.Sphere/Auth/OpenId/GoogleOidcService.cs @@ -11,9 +11,10 @@ public class GoogleOidcService( IConfiguration configuration, IHttpClientFactory httpClientFactory, AppDatabase db, + AuthService auth, ICacheService cache ) - : OidcService(configuration, httpClientFactory, db, cache) + : OidcService(configuration, httpClientFactory, db, auth, cache) { private readonly IHttpClientFactory _httpClientFactory = httpClientFactory; @@ -31,8 +32,6 @@ public class GoogleOidcService( throw new InvalidOperationException("Authorization endpoint not found in discovery document"); } - // PKCE code removed: no code verifier/challenge generated - var queryParams = new Dictionary { { "client_id", config.ClientId }, diff --git a/DysonNetwork.Sphere/Auth/OpenId/MicrosoftOidcService.cs b/DysonNetwork.Sphere/Auth/OpenId/MicrosoftOidcService.cs index 71834ca..83efad1 100644 --- a/DysonNetwork.Sphere/Auth/OpenId/MicrosoftOidcService.cs +++ b/DysonNetwork.Sphere/Auth/OpenId/MicrosoftOidcService.cs @@ -8,9 +8,10 @@ public class MicrosoftOidcService( IConfiguration configuration, IHttpClientFactory httpClientFactory, AppDatabase db, + AuthService auth, ICacheService cache ) - : OidcService(configuration, httpClientFactory, db, cache) + : OidcService(configuration, httpClientFactory, db, auth, cache) { public override string ProviderName => "Microsoft"; diff --git a/DysonNetwork.Sphere/Auth/OpenId/OidcController.cs b/DysonNetwork.Sphere/Auth/OpenId/OidcController.cs index cf1c81c..696ff76 100644 --- a/DysonNetwork.Sphere/Auth/OpenId/OidcController.cs +++ b/DysonNetwork.Sphere/Auth/OpenId/OidcController.cs @@ -13,7 +13,6 @@ public class OidcController( IServiceProvider serviceProvider, AppDatabase db, AccountService accounts, - AuthService auth, ICacheService cache ) : ControllerBase @@ -22,7 +21,11 @@ public class OidcController( private static readonly TimeSpan StateExpiration = TimeSpan.FromMinutes(15); [HttpGet("{provider}")] - public async Task OidcLogin([FromRoute] string provider, [FromQuery] string? returnUrl = "/") + public async Task OidcLogin( + [FromRoute] string provider, + [FromQuery] string? returnUrl = "/", + [FromHeader(Name = "X-Device-Id")] string? deviceId = null + ) { try { @@ -34,9 +37,9 @@ public class OidcController( var state = Guid.NewGuid().ToString(); var nonce = Guid.NewGuid().ToString(); - // Store user's ID, provider, and nonce in cache. The callback will use this. - var stateValue = $"{currentUser.Id}|{provider}|{nonce}"; - await cache.SetAsync($"{StateCachePrefix}{state}", stateValue, StateExpiration); + // Create and store connection state + var oidcState = OidcState.ForConnection(currentUser.Id, provider, nonce, deviceId); + await cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration); // The state parameter sent to the provider is the GUID key for the cache. var authUrl = oidcService.GetAuthorizationUrl(state, nonce); @@ -45,9 +48,12 @@ public class OidcController( else // Otherwise, proceed with the login / registration flow { var nonce = Guid.NewGuid().ToString(); + var state = Guid.NewGuid().ToString(); - // The state parameter is the returnUrl. The callback will not find a session state and will treat it as a login. - var authUrl = oidcService.GetAuthorizationUrl(returnUrl ?? "/", nonce); + // Create login state with return URL and device ID + var oidcState = OidcState.ForLogin(returnUrl ?? "/", deviceId); + await cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration); + var authUrl = oidcService.GetAuthorizationUrl(state, nonce); return Redirect(authUrl); } } @@ -62,7 +68,7 @@ public class OidcController( /// Handles Apple authentication directly from mobile apps /// [HttpPost("apple/mobile")] - public async Task> AppleMobileSignIn( + public async Task> AppleMobileLogin( [FromBody] AppleMobileSignInRequest request) { try @@ -85,17 +91,14 @@ public class OidcController( var account = await FindOrCreateAccount(userInfo, "apple"); // Create session using the OIDC service - var session = await appleService.CreateSessionForUserAsync( + var challenge = await appleService.CreateChallengeForUserAsync( userInfo, account, HttpContext, request.DeviceId ); - // Generate token using existing auth service - var token = auth.CreateToken(session); - - return Ok(new AuthController.TokenExchangeResponse { Token = token }); + return Ok(challenge); } catch (SecurityTokenValidationException ex) { diff --git a/DysonNetwork.Sphere/Auth/OpenId/OidcService.cs b/DysonNetwork.Sphere/Auth/OpenId/OidcService.cs index 53ce96c..a4ddba7 100644 --- a/DysonNetwork.Sphere/Auth/OpenId/OidcService.cs +++ b/DysonNetwork.Sphere/Auth/OpenId/OidcService.cs @@ -16,6 +16,7 @@ public abstract class OidcService( IConfiguration configuration, IHttpClientFactory httpClientFactory, AppDatabase db, + AuthService auth, ICacheService cache ) { @@ -187,7 +188,7 @@ public abstract class OidcService( /// Creates a challenge and session for an authenticated user /// Also creates or updates the account connection /// - public async Task CreateSessionForUserAsync( + public async Task CreateChallengeForUserAsync( OidcUserInfo userInfo, Account.Account account, HttpContext request, @@ -220,8 +221,7 @@ public abstract class OidcService( var challenge = new Challenge { ExpiredAt = now.Plus(Duration.FromHours(1)), - StepTotal = 1, - StepRemain = 0, // Already verified by provider + StepTotal = await auth.DetectChallengeRisk(request.Request, account), Type = ChallengeType.Oidc, Platform = ChallengePlatform.Unidentified, Audiences = [ProviderName], @@ -231,22 +231,13 @@ public abstract class OidcService( IpAddress = request.Connection.RemoteIpAddress?.ToString() ?? null, UserAgent = request.Request.Headers.UserAgent, }; + challenge.StepRemain--; + if (challenge.StepRemain < 0) challenge.StepRemain = 0; await Db.AuthChallenges.AddAsync(challenge); - - // Create a session - var session = new Session - { - AccountId = account.Id, - CreatedAt = now, - LastGrantedAt = now, - Challenge = challenge - }; - - await Db.AuthSessions.AddAsync(session); await Db.SaveChangesAsync(); - return session; + return challenge; } } diff --git a/DysonNetwork.Sphere/Auth/OpenId/OidcState.cs b/DysonNetwork.Sphere/Auth/OpenId/OidcState.cs new file mode 100644 index 0000000..608956e --- /dev/null +++ b/DysonNetwork.Sphere/Auth/OpenId/OidcState.cs @@ -0,0 +1,189 @@ +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace DysonNetwork.Sphere.Auth.OpenId; + +/// +/// Represents the state parameter used in OpenID Connect flows. +/// Handles serialization and deserialization of the state parameter. +/// +public class OidcState +{ + /// + /// The type of OIDC flow (login or connect). + /// + public OidcFlowType FlowType { get; set; } + + /// + /// The account ID (for connect flow). + /// + public Guid? AccountId { get; set; } + + + /// + /// The OIDC provider name. + /// + public string? Provider { get; set; } + + + /// + /// The nonce for CSRF protection. + /// + public string? Nonce { get; set; } + + + /// + /// The device ID for the authentication request. + /// + public string? DeviceId { get; set; } + + + /// + /// The return URL after authentication (for login flow). + /// + public string? ReturnUrl { get; set; } + + + /// + /// Creates a new OidcState for a connection flow. + /// + public static OidcState ForConnection(Guid accountId, string provider, string nonce, string? deviceId = null) + { + return new OidcState + { + FlowType = OidcFlowType.Connect, + AccountId = accountId, + Provider = provider, + Nonce = nonce, + DeviceId = deviceId + }; + } + + /// + /// Creates a new OidcState for a login flow. + /// + public static OidcState ForLogin(string returnUrl = "/", string? deviceId = null) + { + return new OidcState + { + FlowType = OidcFlowType.Login, + ReturnUrl = returnUrl, + DeviceId = deviceId + }; + } + + /// + /// The version of the state format. + /// + public int Version { get; set; } = 1; + + /// + /// Serializes the state to a JSON string for use in OIDC flows. + /// + public string Serialize() + { + return JsonSerializer.Serialize(this, new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }); + } + + /// + /// Attempts to parse a state string into an OidcState object. + /// + public static bool TryParse(string? stateString, out OidcState? state) + { + state = null; + + if (string.IsNullOrEmpty(stateString)) + return false; + + try + { + // First try to parse as JSON + try + { + state = JsonSerializer.Deserialize(stateString); + return state != null; + } + catch (JsonException) + { + // Not a JSON string, try legacy format for backward compatibility + return TryParseLegacyFormat(stateString, out state); + } + } + catch + { + return false; + } + } + + private static bool TryParseLegacyFormat(string stateString, out OidcState? state) + { + state = null; + var parts = stateString.Split('|'); + + // Check for connection flow format: {accountId}|{provider}|{nonce}|{deviceId}|connect + if (parts.Length >= 5 && + Guid.TryParse(parts[0], out var accountId) && + string.Equals(parts[^1], "connect", StringComparison.OrdinalIgnoreCase)) + { + state = new OidcState + { + FlowType = OidcFlowType.Connect, + AccountId = accountId, + Provider = parts[1], + Nonce = parts[2], + DeviceId = parts.Length >= 4 && !string.IsNullOrEmpty(parts[3]) ? parts[3] : null + }; + return true; + } + + // Check for login flow format: {returnUrl}|{deviceId}|login + if (parts.Length >= 2 && + parts.Length <= 3 && + (parts.Length < 3 || string.Equals(parts[^1], "login", StringComparison.OrdinalIgnoreCase))) + { + state = new OidcState + { + FlowType = OidcFlowType.Login, + ReturnUrl = parts[0], + DeviceId = parts.Length >= 2 && !string.IsNullOrEmpty(parts[1]) ? parts[1] : null + }; + return true; + } + + // Legacy format support (for backward compatibility) + if (parts.Length == 1) + { + state = new OidcState + { + FlowType = OidcFlowType.Login, + ReturnUrl = parts[0], + DeviceId = null + }; + return true; + } + + + return false; + } +} + +/// +/// Represents the type of OIDC flow. +/// +public enum OidcFlowType +{ + /// + /// Login or registration flow. + /// + Login, + + + /// + /// Account connection flow. + /// + Connect +} \ No newline at end of file