💄 Optimize oidc provider

This commit is contained in:
2025-11-02 14:35:02 +08:00
parent 5445df3b61
commit 08f924f647

View File

@@ -27,19 +27,47 @@ public class OidcProviderService(
{
private readonly OidcProviderOptions _options = options.Value;
private const string CacheKeyPrefixClientId = "auth:oidc-client:id:";
private const string CacheKeyPrefixClientSlug = "auth:oidc-client:slug:";
private const string CacheKeyPrefixAuthCode = "auth:oidc-code:";
private const string CodeChallengeMethodS256 = "S256";
private const string CodeChallengeMethodPlain = "PLAIN";
public async Task<CustomApp?> FindClientByIdAsync(Guid clientId)
{
var cacheKey = $"{CacheKeyPrefixClientId}{clientId}";
var (found, cachedApp) = await cache.GetAsyncWithStatus<CustomApp>(cacheKey);
if (found && cachedApp != null)
{
return cachedApp;
}
var resp = await customApps.GetCustomAppAsync(new GetCustomAppRequest { Id = clientId.ToString() });
if (resp.App != null)
{
await cache.SetAsync(cacheKey, resp.App, TimeSpan.FromMinutes(5));
}
return resp.App ?? null;
}
public async Task<CustomApp?> FindClientBySlugAsync(string slug)
{
var cacheKey = $"{CacheKeyPrefixClientSlug}{slug}";
var (found, cachedApp) = await cache.GetAsyncWithStatus<CustomApp>(cacheKey);
if (found && cachedApp != null)
{
return cachedApp;
}
var resp = await customApps.GetCustomAppAsync(new GetCustomAppRequest { Slug = slug });
if (resp.App != null)
{
await cache.SetAsync(cacheKey, resp.App, TimeSpan.FromMinutes(5));
}
return resp.App ?? null;
}
public async Task<SnAuthSession?> FindValidSessionAsync(Guid accountId, Guid clientId, bool withAccount = false)
private async Task<SnAuthSession?> FindValidSessionAsync(Guid accountId, Guid clientId, bool withAccount = false)
{
var now = SystemClock.Instance.GetCurrentInstant();
@@ -74,70 +102,79 @@ public class OidcProviderService(
return resp.Valid;
}
private static bool IsWildcardRedirectUriMatch(string allowedUri, string redirectUri)
{
if (string.IsNullOrEmpty(allowedUri) || string.IsNullOrEmpty(redirectUri))
return false;
// Check if it's an exact match
if (string.Equals(allowedUri, redirectUri, StringComparison.Ordinal))
return true;
// Quick check for wildcard patterns
if (!allowedUri.Contains('*'))
return false;
// Parse URIs once
Uri? allowedUriObj, redirectUriObj;
try
{
allowedUriObj = new Uri(allowedUri);
redirectUriObj = new Uri(redirectUri);
}
catch (UriFormatException)
{
return false;
}
// Check scheme and port matches
if (allowedUriObj.Scheme != redirectUriObj.Scheme || allowedUriObj.Port != redirectUriObj.Port)
{
return false;
}
var allowedHost = allowedUriObj.Host;
var redirectHost = redirectUriObj.Host;
// Handle wildcard domain patterns like *.example.com
if (allowedHost.StartsWith("*."))
{
var baseDomain = allowedHost[2..]; // Remove "*."
if (redirectHost == baseDomain || redirectHost.EndsWith("." + baseDomain))
{
// Check path match
var allowedPath = allowedUriObj.AbsolutePath.TrimEnd('/');
var redirectPath = redirectUriObj.AbsolutePath.TrimEnd('/');
// If allowed path is empty, any path is allowed
// If allowed path is specified, redirect path must start with it
return string.IsNullOrEmpty(allowedPath) ||
redirectPath.StartsWith(allowedPath, StringComparison.OrdinalIgnoreCase);
}
}
return false;
}
public async Task<bool> ValidateRedirectUriAsync(Guid clientId, string redirectUri)
{
if (string.IsNullOrEmpty(redirectUri))
return false;
var client = await FindClientByIdAsync(clientId);
if (client?.Status != Shared.Proto.CustomAppStatus.Production)
return true;
if (client?.OauthConfig?.RedirectUris == null)
var redirectUris = client?.OauthConfig?.RedirectUris;
if (redirectUris == null || redirectUris.Count == 0)
return false;
// Check if the redirect URI matches any of the allowed URIs
// For exact match
if (client.OauthConfig.RedirectUris.Contains(redirectUri))
return true;
// Check for wildcard matches (e.g., https://*.example.com/*)
foreach (var allowedUri in client.OauthConfig.RedirectUris)
// Check each allowed URI for a match
foreach (var allowedUri in redirectUris)
{
if (string.IsNullOrEmpty(allowedUri))
continue;
// Handle wildcard in domain
if (allowedUri.Contains("*.") && allowedUri.StartsWith("http"))
if (IsWildcardRedirectUriMatch(allowedUri, redirectUri))
{
try
{
var allowedUriObj = new Uri(allowedUri);
var redirectUriObj = new Uri(redirectUri);
if (allowedUriObj.Scheme != redirectUriObj.Scheme ||
allowedUriObj.Port != redirectUriObj.Port)
{
continue;
}
// Check if the domain matches the wildcard pattern
var allowedDomain = allowedUriObj.Host;
var redirectDomain = redirectUriObj.Host;
if (allowedDomain.StartsWith("*."))
{
var baseDomain = allowedDomain[2..]; // Remove the "*." prefix
if (redirectDomain == baseDomain || redirectDomain.EndsWith($".{baseDomain}"))
{
// Check path
var allowedPath = allowedUriObj.AbsolutePath.TrimEnd('/');
var redirectPath = redirectUriObj.AbsolutePath.TrimEnd('/');
if (string.IsNullOrEmpty(allowedPath) ||
redirectPath.StartsWith(allowedPath, StringComparison.OrdinalIgnoreCase))
{
return true;
}
}
}
}
catch (UriFormatException)
{
// Invalid URI format in allowed URIs, skip
continue;
}
return true;
}
}
@@ -219,6 +256,53 @@ public class OidcProviderService(
return tokenHandler.WriteToken(token);
}
private async Task<(SnAuthSession session, string? nonce, List<string>? scopes)> HandleAuthorizationCodeFlowAsync(
string authorizationCode,
Guid clientId,
string? redirectUri,
string? codeVerifier
)
{
var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId, redirectUri, codeVerifier);
if (authCode == null)
throw new InvalidOperationException("Invalid authorization code");
// Load the session for the user
var existingSession = await FindValidSessionAsync(authCode.AccountId, clientId, withAccount: true);
SnAuthSession session;
if (existingSession == null)
{
var account = await db.Accounts
.Where(a => a.Id == authCode.AccountId)
.Include(a => a.Profile)
.Include(a => a.Contacts)
.FirstOrDefaultAsync();
if (account == null) throw new InvalidOperationException("Account not found");
session = await auth.CreateSessionForOidcAsync(account, SystemClock.Instance.GetCurrentInstant(), clientId);
session.Account = account;
}
else
{
session = existingSession;
}
return (session, authCode.Nonce, authCode.Scopes);
}
private async Task<(SnAuthSession session, string? nonce, List<string>? scopes)> HandleRefreshTokenFlowAsync(Guid sessionId)
{
var session = await FindSessionByIdAsync(sessionId) ??
throw new InvalidOperationException("Session not found");
// Verify the session is still valid
var now = SystemClock.Instance.GetCurrentInstant();
if (session.ExpiredAt.HasValue && session.ExpiredAt < now)
throw new InvalidOperationException("Session has expired");
return (session, null, null);
}
public async Task<TokenResponse> GenerateTokenResponseAsync(
Guid clientId,
string? authorizationCode = null,
@@ -227,58 +311,18 @@ public class OidcProviderService(
Guid? sessionId = null
)
{
if (clientId == Guid.Empty) throw new ArgumentException("Client ID cannot be empty", nameof(clientId));
var client = await FindClientByIdAsync(clientId) ?? throw new InvalidOperationException("Client not found");
SnAuthSession session;
var (session, nonce, scopes) = authorizationCode != null
? await HandleAuthorizationCodeFlowAsync(authorizationCode, clientId, redirectUri, codeVerifier)
: sessionId.HasValue
? await HandleRefreshTokenFlowAsync(sessionId.Value)
: throw new InvalidOperationException("Either authorization code or session ID must be provided");
var clock = SystemClock.Instance;
var now = clock.GetCurrentInstant();
string? nonce = null;
List<string>? scopes = null;
if (authorizationCode != null)
{
// Authorization code flow
var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId, redirectUri, codeVerifier);
if (authCode == null)
throw new InvalidOperationException("Invalid authorization code");
// Load the session for the user
var existingSession = await FindValidSessionAsync(authCode.AccountId, clientId, withAccount: true);
if (existingSession is null)
{
var account = await db.Accounts
.Where(a => a.Id == authCode.AccountId)
.Include(a => a.Profile)
.Include(a => a.Contacts)
.FirstOrDefaultAsync();
if (account is null) throw new InvalidOperationException("Account not found");
session = await auth.CreateSessionForOidcAsync(account, clock.GetCurrentInstant(), clientId);
session.Account = account;
}
else
{
session = existingSession;
}
scopes = authCode.Scopes;
nonce = authCode.Nonce;
}
else if (sessionId.HasValue)
{
// Refresh token flow
session = await FindSessionByIdAsync(sessionId.Value) ??
throw new InvalidOperationException("Session not found");
// Verify the session is still valid
if (session.ExpiredAt < now)
throw new InvalidOperationException("Session has expired");
}
else
{
throw new InvalidOperationException("Either authorization code or session ID must be provided");
}
var expiresIn = (int)_options.AccessTokenLifetime.TotalSeconds;
var expiresAt = now.Plus(Duration.FromSeconds(expiresIn));
@@ -415,7 +459,7 @@ public class OidcProviderService(
};
// Store the code with its metadata in the cache
var cacheKey = $"auth:oidc-code:{code}";
var cacheKey = $"{CacheKeyPrefixAuthCode}{code}";
await cache.SetAsync(cacheKey, authCodeInfo, _options.AuthorizationCodeLifetime);
logger.LogInformation("Generated authorization code for client {ClientId} and user {UserId}", clientId, userId);
@@ -429,7 +473,7 @@ public class OidcProviderService(
string? codeVerifier = null
)
{
var cacheKey = $"auth:oidc-code:{code}";
var cacheKey = $"{CacheKeyPrefixAuthCode}{code}";
var (found, authCode) = await cache.GetAsyncWithStatus<AuthorizationCodeInfo>(cacheKey);
if (!found || authCode == null)
@@ -465,8 +509,8 @@ public class OidcProviderService(
var isValid = authCode.CodeChallengeMethod?.ToUpperInvariant() switch
{
"S256" => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, "S256"),
"PLAIN" => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, "PLAIN"),
CodeChallengeMethodS256 => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, CodeChallengeMethodS256),
CodeChallengeMethodPlain => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, CodeChallengeMethodPlain),
_ => false // Unsupported code challenge method
};
@@ -504,19 +548,12 @@ public class OidcProviderService(
{
if (string.IsNullOrEmpty(codeVerifier)) return false;
if (method == "S256")
{
using var sha256 = SHA256.Create();
var hash = sha256.ComputeHash(Encoding.UTF8.GetBytes(codeVerifier));
var base64 = Base64UrlEncoder.Encode(hash);
return string.Equals(base64, codeChallenge, StringComparison.Ordinal);
}
if (method != CodeChallengeMethodS256)
return method == CodeChallengeMethodPlain &&
string.Equals(codeVerifier, codeChallenge, StringComparison.Ordinal);
var hash = SHA256.HashData(Encoding.UTF8.GetBytes(codeVerifier));
var base64 = Base64UrlEncoder.Encode(hash);
if (method == "PLAIN")
{
return string.Equals(codeVerifier, codeChallenge, StringComparison.Ordinal);
}
return false;
return string.Equals(base64, codeChallenge, StringComparison.Ordinal);
}
}