diff --git a/DysonNetwork.Pass/Auth/OidcProvider/Services/OidcProviderService.cs b/DysonNetwork.Pass/Auth/OidcProvider/Services/OidcProviderService.cs index b3f720a..6dd764d 100644 --- a/DysonNetwork.Pass/Auth/OidcProvider/Services/OidcProviderService.cs +++ b/DysonNetwork.Pass/Auth/OidcProvider/Services/OidcProviderService.cs @@ -27,19 +27,47 @@ public class OidcProviderService( { private readonly OidcProviderOptions _options = options.Value; + private const string CacheKeyPrefixClientId = "auth:oidc-client:id:"; + private const string CacheKeyPrefixClientSlug = "auth:oidc-client:slug:"; + private const string CacheKeyPrefixAuthCode = "auth:oidc-code:"; + private const string CodeChallengeMethodS256 = "S256"; + private const string CodeChallengeMethodPlain = "PLAIN"; + public async Task FindClientByIdAsync(Guid clientId) { + var cacheKey = $"{CacheKeyPrefixClientId}{clientId}"; + var (found, cachedApp) = await cache.GetAsyncWithStatus(cacheKey); + if (found && cachedApp != null) + { + return cachedApp; + } + var resp = await customApps.GetCustomAppAsync(new GetCustomAppRequest { Id = clientId.ToString() }); + if (resp.App != null) + { + await cache.SetAsync(cacheKey, resp.App, TimeSpan.FromMinutes(5)); + } return resp.App ?? null; } public async Task FindClientBySlugAsync(string slug) { + var cacheKey = $"{CacheKeyPrefixClientSlug}{slug}"; + var (found, cachedApp) = await cache.GetAsyncWithStatus(cacheKey); + if (found && cachedApp != null) + { + return cachedApp; + } + var resp = await customApps.GetCustomAppAsync(new GetCustomAppRequest { Slug = slug }); + if (resp.App != null) + { + await cache.SetAsync(cacheKey, resp.App, TimeSpan.FromMinutes(5)); + } return resp.App ?? null; } - public async Task FindValidSessionAsync(Guid accountId, Guid clientId, bool withAccount = false) + private async Task FindValidSessionAsync(Guid accountId, Guid clientId, bool withAccount = false) { var now = SystemClock.Instance.GetCurrentInstant(); @@ -74,70 +102,79 @@ public class OidcProviderService( return resp.Valid; } + private static bool IsWildcardRedirectUriMatch(string allowedUri, string redirectUri) + { + if (string.IsNullOrEmpty(allowedUri) || string.IsNullOrEmpty(redirectUri)) + return false; + + // Check if it's an exact match + if (string.Equals(allowedUri, redirectUri, StringComparison.Ordinal)) + return true; + + // Quick check for wildcard patterns + if (!allowedUri.Contains('*')) + return false; + + // Parse URIs once + Uri? allowedUriObj, redirectUriObj; + try + { + allowedUriObj = new Uri(allowedUri); + redirectUriObj = new Uri(redirectUri); + } + catch (UriFormatException) + { + return false; + } + + // Check scheme and port matches + if (allowedUriObj.Scheme != redirectUriObj.Scheme || allowedUriObj.Port != redirectUriObj.Port) + { + return false; + } + + var allowedHost = allowedUriObj.Host; + var redirectHost = redirectUriObj.Host; + + // Handle wildcard domain patterns like *.example.com + if (allowedHost.StartsWith("*.")) + { + var baseDomain = allowedHost[2..]; // Remove "*." + if (redirectHost == baseDomain || redirectHost.EndsWith("." + baseDomain)) + { + // Check path match + var allowedPath = allowedUriObj.AbsolutePath.TrimEnd('/'); + var redirectPath = redirectUriObj.AbsolutePath.TrimEnd('/'); + + // If allowed path is empty, any path is allowed + // If allowed path is specified, redirect path must start with it + return string.IsNullOrEmpty(allowedPath) || + redirectPath.StartsWith(allowedPath, StringComparison.OrdinalIgnoreCase); + } + } + + return false; + } + public async Task ValidateRedirectUriAsync(Guid clientId, string redirectUri) { if (string.IsNullOrEmpty(redirectUri)) return false; - var client = await FindClientByIdAsync(clientId); if (client?.Status != Shared.Proto.CustomAppStatus.Production) return true; - if (client?.OauthConfig?.RedirectUris == null) + var redirectUris = client?.OauthConfig?.RedirectUris; + if (redirectUris == null || redirectUris.Count == 0) return false; - // Check if the redirect URI matches any of the allowed URIs - // For exact match - if (client.OauthConfig.RedirectUris.Contains(redirectUri)) - return true; - - // Check for wildcard matches (e.g., https://*.example.com/*) - foreach (var allowedUri in client.OauthConfig.RedirectUris) + // Check each allowed URI for a match + foreach (var allowedUri in redirectUris) { - if (string.IsNullOrEmpty(allowedUri)) - continue; - - // Handle wildcard in domain - if (allowedUri.Contains("*.") && allowedUri.StartsWith("http")) + if (IsWildcardRedirectUriMatch(allowedUri, redirectUri)) { - try - { - var allowedUriObj = new Uri(allowedUri); - var redirectUriObj = new Uri(redirectUri); - - if (allowedUriObj.Scheme != redirectUriObj.Scheme || - allowedUriObj.Port != redirectUriObj.Port) - { - continue; - } - - // Check if the domain matches the wildcard pattern - var allowedDomain = allowedUriObj.Host; - var redirectDomain = redirectUriObj.Host; - - if (allowedDomain.StartsWith("*.")) - { - var baseDomain = allowedDomain[2..]; // Remove the "*." prefix - if (redirectDomain == baseDomain || redirectDomain.EndsWith($".{baseDomain}")) - { - // Check path - var allowedPath = allowedUriObj.AbsolutePath.TrimEnd('/'); - var redirectPath = redirectUriObj.AbsolutePath.TrimEnd('/'); - - if (string.IsNullOrEmpty(allowedPath) || - redirectPath.StartsWith(allowedPath, StringComparison.OrdinalIgnoreCase)) - { - return true; - } - } - } - } - catch (UriFormatException) - { - // Invalid URI format in allowed URIs, skip - continue; - } + return true; } } @@ -219,6 +256,53 @@ public class OidcProviderService( return tokenHandler.WriteToken(token); } + private async Task<(SnAuthSession session, string? nonce, List? scopes)> HandleAuthorizationCodeFlowAsync( + string authorizationCode, + Guid clientId, + string? redirectUri, + string? codeVerifier + ) + { + var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId, redirectUri, codeVerifier); + if (authCode == null) + throw new InvalidOperationException("Invalid authorization code"); + + // Load the session for the user + var existingSession = await FindValidSessionAsync(authCode.AccountId, clientId, withAccount: true); + + SnAuthSession session; + if (existingSession == null) + { + var account = await db.Accounts + .Where(a => a.Id == authCode.AccountId) + .Include(a => a.Profile) + .Include(a => a.Contacts) + .FirstOrDefaultAsync(); + if (account == null) throw new InvalidOperationException("Account not found"); + session = await auth.CreateSessionForOidcAsync(account, SystemClock.Instance.GetCurrentInstant(), clientId); + session.Account = account; + } + else + { + session = existingSession; + } + + return (session, authCode.Nonce, authCode.Scopes); + } + + private async Task<(SnAuthSession session, string? nonce, List? scopes)> HandleRefreshTokenFlowAsync(Guid sessionId) + { + var session = await FindSessionByIdAsync(sessionId) ?? + throw new InvalidOperationException("Session not found"); + + // Verify the session is still valid + var now = SystemClock.Instance.GetCurrentInstant(); + if (session.ExpiredAt.HasValue && session.ExpiredAt < now) + throw new InvalidOperationException("Session has expired"); + + return (session, null, null); + } + public async Task GenerateTokenResponseAsync( Guid clientId, string? authorizationCode = null, @@ -227,58 +311,18 @@ public class OidcProviderService( Guid? sessionId = null ) { + if (clientId == Guid.Empty) throw new ArgumentException("Client ID cannot be empty", nameof(clientId)); + var client = await FindClientByIdAsync(clientId) ?? throw new InvalidOperationException("Client not found"); - SnAuthSession session; + var (session, nonce, scopes) = authorizationCode != null + ? await HandleAuthorizationCodeFlowAsync(authorizationCode, clientId, redirectUri, codeVerifier) + : sessionId.HasValue + ? await HandleRefreshTokenFlowAsync(sessionId.Value) + : throw new InvalidOperationException("Either authorization code or session ID must be provided"); + var clock = SystemClock.Instance; var now = clock.GetCurrentInstant(); - string? nonce = null; - List? scopes = null; - - if (authorizationCode != null) - { - // Authorization code flow - var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId, redirectUri, codeVerifier); - if (authCode == null) - throw new InvalidOperationException("Invalid authorization code"); - - // Load the session for the user - var existingSession = await FindValidSessionAsync(authCode.AccountId, clientId, withAccount: true); - - if (existingSession is null) - { - var account = await db.Accounts - .Where(a => a.Id == authCode.AccountId) - .Include(a => a.Profile) - .Include(a => a.Contacts) - .FirstOrDefaultAsync(); - if (account is null) throw new InvalidOperationException("Account not found"); - session = await auth.CreateSessionForOidcAsync(account, clock.GetCurrentInstant(), clientId); - session.Account = account; - } - else - { - session = existingSession; - } - - scopes = authCode.Scopes; - nonce = authCode.Nonce; - } - else if (sessionId.HasValue) - { - // Refresh token flow - session = await FindSessionByIdAsync(sessionId.Value) ?? - throw new InvalidOperationException("Session not found"); - - // Verify the session is still valid - if (session.ExpiredAt < now) - throw new InvalidOperationException("Session has expired"); - } - else - { - throw new InvalidOperationException("Either authorization code or session ID must be provided"); - } - var expiresIn = (int)_options.AccessTokenLifetime.TotalSeconds; var expiresAt = now.Plus(Duration.FromSeconds(expiresIn)); @@ -415,7 +459,7 @@ public class OidcProviderService( }; // Store the code with its metadata in the cache - var cacheKey = $"auth:oidc-code:{code}"; + var cacheKey = $"{CacheKeyPrefixAuthCode}{code}"; await cache.SetAsync(cacheKey, authCodeInfo, _options.AuthorizationCodeLifetime); logger.LogInformation("Generated authorization code for client {ClientId} and user {UserId}", clientId, userId); @@ -429,7 +473,7 @@ public class OidcProviderService( string? codeVerifier = null ) { - var cacheKey = $"auth:oidc-code:{code}"; + var cacheKey = $"{CacheKeyPrefixAuthCode}{code}"; var (found, authCode) = await cache.GetAsyncWithStatus(cacheKey); if (!found || authCode == null) @@ -465,8 +509,8 @@ public class OidcProviderService( var isValid = authCode.CodeChallengeMethod?.ToUpperInvariant() switch { - "S256" => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, "S256"), - "PLAIN" => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, "PLAIN"), + CodeChallengeMethodS256 => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, CodeChallengeMethodS256), + CodeChallengeMethodPlain => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, CodeChallengeMethodPlain), _ => false // Unsupported code challenge method }; @@ -504,19 +548,12 @@ public class OidcProviderService( { if (string.IsNullOrEmpty(codeVerifier)) return false; - if (method == "S256") - { - using var sha256 = SHA256.Create(); - var hash = sha256.ComputeHash(Encoding.UTF8.GetBytes(codeVerifier)); - var base64 = Base64UrlEncoder.Encode(hash); - return string.Equals(base64, codeChallenge, StringComparison.Ordinal); - } + if (method != CodeChallengeMethodS256) + return method == CodeChallengeMethodPlain && + string.Equals(codeVerifier, codeChallenge, StringComparison.Ordinal); + var hash = SHA256.HashData(Encoding.UTF8.GetBytes(codeVerifier)); + var base64 = Base64UrlEncoder.Encode(hash); - if (method == "PLAIN") - { - return string.Equals(codeVerifier, codeChallenge, StringComparison.Ordinal); - } - - return false; + return string.Equals(base64, codeChallenge, StringComparison.Ordinal); } }