💄 Optimize oidc provider
This commit is contained in:
@@ -27,19 +27,47 @@ public class OidcProviderService(
|
||||
{
|
||||
private readonly OidcProviderOptions _options = options.Value;
|
||||
|
||||
private const string CacheKeyPrefixClientId = "auth:oidc-client:id:";
|
||||
private const string CacheKeyPrefixClientSlug = "auth:oidc-client:slug:";
|
||||
private const string CacheKeyPrefixAuthCode = "auth:oidc-code:";
|
||||
private const string CodeChallengeMethodS256 = "S256";
|
||||
private const string CodeChallengeMethodPlain = "PLAIN";
|
||||
|
||||
public async Task<CustomApp?> FindClientByIdAsync(Guid clientId)
|
||||
{
|
||||
var cacheKey = $"{CacheKeyPrefixClientId}{clientId}";
|
||||
var (found, cachedApp) = await cache.GetAsyncWithStatus<CustomApp>(cacheKey);
|
||||
if (found && cachedApp != null)
|
||||
{
|
||||
return cachedApp;
|
||||
}
|
||||
|
||||
var resp = await customApps.GetCustomAppAsync(new GetCustomAppRequest { Id = clientId.ToString() });
|
||||
if (resp.App != null)
|
||||
{
|
||||
await cache.SetAsync(cacheKey, resp.App, TimeSpan.FromMinutes(5));
|
||||
}
|
||||
return resp.App ?? null;
|
||||
}
|
||||
|
||||
public async Task<CustomApp?> FindClientBySlugAsync(string slug)
|
||||
{
|
||||
var cacheKey = $"{CacheKeyPrefixClientSlug}{slug}";
|
||||
var (found, cachedApp) = await cache.GetAsyncWithStatus<CustomApp>(cacheKey);
|
||||
if (found && cachedApp != null)
|
||||
{
|
||||
return cachedApp;
|
||||
}
|
||||
|
||||
var resp = await customApps.GetCustomAppAsync(new GetCustomAppRequest { Slug = slug });
|
||||
if (resp.App != null)
|
||||
{
|
||||
await cache.SetAsync(cacheKey, resp.App, TimeSpan.FromMinutes(5));
|
||||
}
|
||||
return resp.App ?? null;
|
||||
}
|
||||
|
||||
public async Task<SnAuthSession?> FindValidSessionAsync(Guid accountId, Guid clientId, bool withAccount = false)
|
||||
private async Task<SnAuthSession?> FindValidSessionAsync(Guid accountId, Guid clientId, bool withAccount = false)
|
||||
{
|
||||
var now = SystemClock.Instance.GetCurrentInstant();
|
||||
|
||||
@@ -74,70 +102,79 @@ public class OidcProviderService(
|
||||
return resp.Valid;
|
||||
}
|
||||
|
||||
private static bool IsWildcardRedirectUriMatch(string allowedUri, string redirectUri)
|
||||
{
|
||||
if (string.IsNullOrEmpty(allowedUri) || string.IsNullOrEmpty(redirectUri))
|
||||
return false;
|
||||
|
||||
// Check if it's an exact match
|
||||
if (string.Equals(allowedUri, redirectUri, StringComparison.Ordinal))
|
||||
return true;
|
||||
|
||||
// Quick check for wildcard patterns
|
||||
if (!allowedUri.Contains('*'))
|
||||
return false;
|
||||
|
||||
// Parse URIs once
|
||||
Uri? allowedUriObj, redirectUriObj;
|
||||
try
|
||||
{
|
||||
allowedUriObj = new Uri(allowedUri);
|
||||
redirectUriObj = new Uri(redirectUri);
|
||||
}
|
||||
catch (UriFormatException)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check scheme and port matches
|
||||
if (allowedUriObj.Scheme != redirectUriObj.Scheme || allowedUriObj.Port != redirectUriObj.Port)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var allowedHost = allowedUriObj.Host;
|
||||
var redirectHost = redirectUriObj.Host;
|
||||
|
||||
// Handle wildcard domain patterns like *.example.com
|
||||
if (allowedHost.StartsWith("*."))
|
||||
{
|
||||
var baseDomain = allowedHost[2..]; // Remove "*."
|
||||
if (redirectHost == baseDomain || redirectHost.EndsWith("." + baseDomain))
|
||||
{
|
||||
// Check path match
|
||||
var allowedPath = allowedUriObj.AbsolutePath.TrimEnd('/');
|
||||
var redirectPath = redirectUriObj.AbsolutePath.TrimEnd('/');
|
||||
|
||||
// If allowed path is empty, any path is allowed
|
||||
// If allowed path is specified, redirect path must start with it
|
||||
return string.IsNullOrEmpty(allowedPath) ||
|
||||
redirectPath.StartsWith(allowedPath, StringComparison.OrdinalIgnoreCase);
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
public async Task<bool> ValidateRedirectUriAsync(Guid clientId, string redirectUri)
|
||||
{
|
||||
if (string.IsNullOrEmpty(redirectUri))
|
||||
return false;
|
||||
|
||||
|
||||
var client = await FindClientByIdAsync(clientId);
|
||||
if (client?.Status != Shared.Proto.CustomAppStatus.Production)
|
||||
return true;
|
||||
|
||||
if (client?.OauthConfig?.RedirectUris == null)
|
||||
var redirectUris = client?.OauthConfig?.RedirectUris;
|
||||
if (redirectUris == null || redirectUris.Count == 0)
|
||||
return false;
|
||||
|
||||
// Check if the redirect URI matches any of the allowed URIs
|
||||
// For exact match
|
||||
if (client.OauthConfig.RedirectUris.Contains(redirectUri))
|
||||
return true;
|
||||
|
||||
// Check for wildcard matches (e.g., https://*.example.com/*)
|
||||
foreach (var allowedUri in client.OauthConfig.RedirectUris)
|
||||
// Check each allowed URI for a match
|
||||
foreach (var allowedUri in redirectUris)
|
||||
{
|
||||
if (string.IsNullOrEmpty(allowedUri))
|
||||
continue;
|
||||
|
||||
// Handle wildcard in domain
|
||||
if (allowedUri.Contains("*.") && allowedUri.StartsWith("http"))
|
||||
if (IsWildcardRedirectUriMatch(allowedUri, redirectUri))
|
||||
{
|
||||
try
|
||||
{
|
||||
var allowedUriObj = new Uri(allowedUri);
|
||||
var redirectUriObj = new Uri(redirectUri);
|
||||
|
||||
if (allowedUriObj.Scheme != redirectUriObj.Scheme ||
|
||||
allowedUriObj.Port != redirectUriObj.Port)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if the domain matches the wildcard pattern
|
||||
var allowedDomain = allowedUriObj.Host;
|
||||
var redirectDomain = redirectUriObj.Host;
|
||||
|
||||
if (allowedDomain.StartsWith("*."))
|
||||
{
|
||||
var baseDomain = allowedDomain[2..]; // Remove the "*." prefix
|
||||
if (redirectDomain == baseDomain || redirectDomain.EndsWith($".{baseDomain}"))
|
||||
{
|
||||
// Check path
|
||||
var allowedPath = allowedUriObj.AbsolutePath.TrimEnd('/');
|
||||
var redirectPath = redirectUriObj.AbsolutePath.TrimEnd('/');
|
||||
|
||||
if (string.IsNullOrEmpty(allowedPath) ||
|
||||
redirectPath.StartsWith(allowedPath, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (UriFormatException)
|
||||
{
|
||||
// Invalid URI format in allowed URIs, skip
|
||||
continue;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -219,6 +256,53 @@ public class OidcProviderService(
|
||||
return tokenHandler.WriteToken(token);
|
||||
}
|
||||
|
||||
private async Task<(SnAuthSession session, string? nonce, List<string>? scopes)> HandleAuthorizationCodeFlowAsync(
|
||||
string authorizationCode,
|
||||
Guid clientId,
|
||||
string? redirectUri,
|
||||
string? codeVerifier
|
||||
)
|
||||
{
|
||||
var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId, redirectUri, codeVerifier);
|
||||
if (authCode == null)
|
||||
throw new InvalidOperationException("Invalid authorization code");
|
||||
|
||||
// Load the session for the user
|
||||
var existingSession = await FindValidSessionAsync(authCode.AccountId, clientId, withAccount: true);
|
||||
|
||||
SnAuthSession session;
|
||||
if (existingSession == null)
|
||||
{
|
||||
var account = await db.Accounts
|
||||
.Where(a => a.Id == authCode.AccountId)
|
||||
.Include(a => a.Profile)
|
||||
.Include(a => a.Contacts)
|
||||
.FirstOrDefaultAsync();
|
||||
if (account == null) throw new InvalidOperationException("Account not found");
|
||||
session = await auth.CreateSessionForOidcAsync(account, SystemClock.Instance.GetCurrentInstant(), clientId);
|
||||
session.Account = account;
|
||||
}
|
||||
else
|
||||
{
|
||||
session = existingSession;
|
||||
}
|
||||
|
||||
return (session, authCode.Nonce, authCode.Scopes);
|
||||
}
|
||||
|
||||
private async Task<(SnAuthSession session, string? nonce, List<string>? scopes)> HandleRefreshTokenFlowAsync(Guid sessionId)
|
||||
{
|
||||
var session = await FindSessionByIdAsync(sessionId) ??
|
||||
throw new InvalidOperationException("Session not found");
|
||||
|
||||
// Verify the session is still valid
|
||||
var now = SystemClock.Instance.GetCurrentInstant();
|
||||
if (session.ExpiredAt.HasValue && session.ExpiredAt < now)
|
||||
throw new InvalidOperationException("Session has expired");
|
||||
|
||||
return (session, null, null);
|
||||
}
|
||||
|
||||
public async Task<TokenResponse> GenerateTokenResponseAsync(
|
||||
Guid clientId,
|
||||
string? authorizationCode = null,
|
||||
@@ -227,58 +311,18 @@ public class OidcProviderService(
|
||||
Guid? sessionId = null
|
||||
)
|
||||
{
|
||||
if (clientId == Guid.Empty) throw new ArgumentException("Client ID cannot be empty", nameof(clientId));
|
||||
|
||||
var client = await FindClientByIdAsync(clientId) ?? throw new InvalidOperationException("Client not found");
|
||||
|
||||
SnAuthSession session;
|
||||
var (session, nonce, scopes) = authorizationCode != null
|
||||
? await HandleAuthorizationCodeFlowAsync(authorizationCode, clientId, redirectUri, codeVerifier)
|
||||
: sessionId.HasValue
|
||||
? await HandleRefreshTokenFlowAsync(sessionId.Value)
|
||||
: throw new InvalidOperationException("Either authorization code or session ID must be provided");
|
||||
|
||||
var clock = SystemClock.Instance;
|
||||
var now = clock.GetCurrentInstant();
|
||||
string? nonce = null;
|
||||
List<string>? scopes = null;
|
||||
|
||||
if (authorizationCode != null)
|
||||
{
|
||||
// Authorization code flow
|
||||
var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId, redirectUri, codeVerifier);
|
||||
if (authCode == null)
|
||||
throw new InvalidOperationException("Invalid authorization code");
|
||||
|
||||
// Load the session for the user
|
||||
var existingSession = await FindValidSessionAsync(authCode.AccountId, clientId, withAccount: true);
|
||||
|
||||
if (existingSession is null)
|
||||
{
|
||||
var account = await db.Accounts
|
||||
.Where(a => a.Id == authCode.AccountId)
|
||||
.Include(a => a.Profile)
|
||||
.Include(a => a.Contacts)
|
||||
.FirstOrDefaultAsync();
|
||||
if (account is null) throw new InvalidOperationException("Account not found");
|
||||
session = await auth.CreateSessionForOidcAsync(account, clock.GetCurrentInstant(), clientId);
|
||||
session.Account = account;
|
||||
}
|
||||
else
|
||||
{
|
||||
session = existingSession;
|
||||
}
|
||||
|
||||
scopes = authCode.Scopes;
|
||||
nonce = authCode.Nonce;
|
||||
}
|
||||
else if (sessionId.HasValue)
|
||||
{
|
||||
// Refresh token flow
|
||||
session = await FindSessionByIdAsync(sessionId.Value) ??
|
||||
throw new InvalidOperationException("Session not found");
|
||||
|
||||
// Verify the session is still valid
|
||||
if (session.ExpiredAt < now)
|
||||
throw new InvalidOperationException("Session has expired");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("Either authorization code or session ID must be provided");
|
||||
}
|
||||
|
||||
var expiresIn = (int)_options.AccessTokenLifetime.TotalSeconds;
|
||||
var expiresAt = now.Plus(Duration.FromSeconds(expiresIn));
|
||||
|
||||
@@ -415,7 +459,7 @@ public class OidcProviderService(
|
||||
};
|
||||
|
||||
// Store the code with its metadata in the cache
|
||||
var cacheKey = $"auth:oidc-code:{code}";
|
||||
var cacheKey = $"{CacheKeyPrefixAuthCode}{code}";
|
||||
await cache.SetAsync(cacheKey, authCodeInfo, _options.AuthorizationCodeLifetime);
|
||||
|
||||
logger.LogInformation("Generated authorization code for client {ClientId} and user {UserId}", clientId, userId);
|
||||
@@ -429,7 +473,7 @@ public class OidcProviderService(
|
||||
string? codeVerifier = null
|
||||
)
|
||||
{
|
||||
var cacheKey = $"auth:oidc-code:{code}";
|
||||
var cacheKey = $"{CacheKeyPrefixAuthCode}{code}";
|
||||
var (found, authCode) = await cache.GetAsyncWithStatus<AuthorizationCodeInfo>(cacheKey);
|
||||
|
||||
if (!found || authCode == null)
|
||||
@@ -465,8 +509,8 @@ public class OidcProviderService(
|
||||
|
||||
var isValid = authCode.CodeChallengeMethod?.ToUpperInvariant() switch
|
||||
{
|
||||
"S256" => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, "S256"),
|
||||
"PLAIN" => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, "PLAIN"),
|
||||
CodeChallengeMethodS256 => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, CodeChallengeMethodS256),
|
||||
CodeChallengeMethodPlain => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, CodeChallengeMethodPlain),
|
||||
_ => false // Unsupported code challenge method
|
||||
};
|
||||
|
||||
@@ -504,19 +548,12 @@ public class OidcProviderService(
|
||||
{
|
||||
if (string.IsNullOrEmpty(codeVerifier)) return false;
|
||||
|
||||
if (method == "S256")
|
||||
{
|
||||
using var sha256 = SHA256.Create();
|
||||
var hash = sha256.ComputeHash(Encoding.UTF8.GetBytes(codeVerifier));
|
||||
var base64 = Base64UrlEncoder.Encode(hash);
|
||||
return string.Equals(base64, codeChallenge, StringComparison.Ordinal);
|
||||
}
|
||||
if (method != CodeChallengeMethodS256)
|
||||
return method == CodeChallengeMethodPlain &&
|
||||
string.Equals(codeVerifier, codeChallenge, StringComparison.Ordinal);
|
||||
var hash = SHA256.HashData(Encoding.UTF8.GetBytes(codeVerifier));
|
||||
var base64 = Base64UrlEncoder.Encode(hash);
|
||||
|
||||
if (method == "PLAIN")
|
||||
{
|
||||
return string.Equals(codeVerifier, codeChallenge, StringComparison.Ordinal);
|
||||
}
|
||||
|
||||
return false;
|
||||
return string.Equals(base64, codeChallenge, StringComparison.Ordinal);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user