💄 Optimize oidc provider
This commit is contained in:
@@ -27,19 +27,47 @@ public class OidcProviderService(
|
|||||||
{
|
{
|
||||||
private readonly OidcProviderOptions _options = options.Value;
|
private readonly OidcProviderOptions _options = options.Value;
|
||||||
|
|
||||||
|
private const string CacheKeyPrefixClientId = "auth:oidc-client:id:";
|
||||||
|
private const string CacheKeyPrefixClientSlug = "auth:oidc-client:slug:";
|
||||||
|
private const string CacheKeyPrefixAuthCode = "auth:oidc-code:";
|
||||||
|
private const string CodeChallengeMethodS256 = "S256";
|
||||||
|
private const string CodeChallengeMethodPlain = "PLAIN";
|
||||||
|
|
||||||
public async Task<CustomApp?> FindClientByIdAsync(Guid clientId)
|
public async Task<CustomApp?> FindClientByIdAsync(Guid clientId)
|
||||||
{
|
{
|
||||||
|
var cacheKey = $"{CacheKeyPrefixClientId}{clientId}";
|
||||||
|
var (found, cachedApp) = await cache.GetAsyncWithStatus<CustomApp>(cacheKey);
|
||||||
|
if (found && cachedApp != null)
|
||||||
|
{
|
||||||
|
return cachedApp;
|
||||||
|
}
|
||||||
|
|
||||||
var resp = await customApps.GetCustomAppAsync(new GetCustomAppRequest { Id = clientId.ToString() });
|
var resp = await customApps.GetCustomAppAsync(new GetCustomAppRequest { Id = clientId.ToString() });
|
||||||
|
if (resp.App != null)
|
||||||
|
{
|
||||||
|
await cache.SetAsync(cacheKey, resp.App, TimeSpan.FromMinutes(5));
|
||||||
|
}
|
||||||
return resp.App ?? null;
|
return resp.App ?? null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<CustomApp?> FindClientBySlugAsync(string slug)
|
public async Task<CustomApp?> FindClientBySlugAsync(string slug)
|
||||||
{
|
{
|
||||||
|
var cacheKey = $"{CacheKeyPrefixClientSlug}{slug}";
|
||||||
|
var (found, cachedApp) = await cache.GetAsyncWithStatus<CustomApp>(cacheKey);
|
||||||
|
if (found && cachedApp != null)
|
||||||
|
{
|
||||||
|
return cachedApp;
|
||||||
|
}
|
||||||
|
|
||||||
var resp = await customApps.GetCustomAppAsync(new GetCustomAppRequest { Slug = slug });
|
var resp = await customApps.GetCustomAppAsync(new GetCustomAppRequest { Slug = slug });
|
||||||
|
if (resp.App != null)
|
||||||
|
{
|
||||||
|
await cache.SetAsync(cacheKey, resp.App, TimeSpan.FromMinutes(5));
|
||||||
|
}
|
||||||
return resp.App ?? null;
|
return resp.App ?? null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<SnAuthSession?> FindValidSessionAsync(Guid accountId, Guid clientId, bool withAccount = false)
|
private async Task<SnAuthSession?> FindValidSessionAsync(Guid accountId, Guid clientId, bool withAccount = false)
|
||||||
{
|
{
|
||||||
var now = SystemClock.Instance.GetCurrentInstant();
|
var now = SystemClock.Instance.GetCurrentInstant();
|
||||||
|
|
||||||
@@ -74,72 +102,81 @@ public class OidcProviderService(
|
|||||||
return resp.Valid;
|
return resp.Valid;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static bool IsWildcardRedirectUriMatch(string allowedUri, string redirectUri)
|
||||||
|
{
|
||||||
|
if (string.IsNullOrEmpty(allowedUri) || string.IsNullOrEmpty(redirectUri))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// Check if it's an exact match
|
||||||
|
if (string.Equals(allowedUri, redirectUri, StringComparison.Ordinal))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
// Quick check for wildcard patterns
|
||||||
|
if (!allowedUri.Contains('*'))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// Parse URIs once
|
||||||
|
Uri? allowedUriObj, redirectUriObj;
|
||||||
|
try
|
||||||
|
{
|
||||||
|
allowedUriObj = new Uri(allowedUri);
|
||||||
|
redirectUriObj = new Uri(redirectUri);
|
||||||
|
}
|
||||||
|
catch (UriFormatException)
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check scheme and port matches
|
||||||
|
if (allowedUriObj.Scheme != redirectUriObj.Scheme || allowedUriObj.Port != redirectUriObj.Port)
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
var allowedHost = allowedUriObj.Host;
|
||||||
|
var redirectHost = redirectUriObj.Host;
|
||||||
|
|
||||||
|
// Handle wildcard domain patterns like *.example.com
|
||||||
|
if (allowedHost.StartsWith("*."))
|
||||||
|
{
|
||||||
|
var baseDomain = allowedHost[2..]; // Remove "*."
|
||||||
|
if (redirectHost == baseDomain || redirectHost.EndsWith("." + baseDomain))
|
||||||
|
{
|
||||||
|
// Check path match
|
||||||
|
var allowedPath = allowedUriObj.AbsolutePath.TrimEnd('/');
|
||||||
|
var redirectPath = redirectUriObj.AbsolutePath.TrimEnd('/');
|
||||||
|
|
||||||
|
// If allowed path is empty, any path is allowed
|
||||||
|
// If allowed path is specified, redirect path must start with it
|
||||||
|
return string.IsNullOrEmpty(allowedPath) ||
|
||||||
|
redirectPath.StartsWith(allowedPath, StringComparison.OrdinalIgnoreCase);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
public async Task<bool> ValidateRedirectUriAsync(Guid clientId, string redirectUri)
|
public async Task<bool> ValidateRedirectUriAsync(Guid clientId, string redirectUri)
|
||||||
{
|
{
|
||||||
if (string.IsNullOrEmpty(redirectUri))
|
if (string.IsNullOrEmpty(redirectUri))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
|
|
||||||
var client = await FindClientByIdAsync(clientId);
|
var client = await FindClientByIdAsync(clientId);
|
||||||
if (client?.Status != Shared.Proto.CustomAppStatus.Production)
|
if (client?.Status != Shared.Proto.CustomAppStatus.Production)
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
if (client?.OauthConfig?.RedirectUris == null)
|
var redirectUris = client?.OauthConfig?.RedirectUris;
|
||||||
|
if (redirectUris == null || redirectUris.Count == 0)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// Check if the redirect URI matches any of the allowed URIs
|
// Check each allowed URI for a match
|
||||||
// For exact match
|
foreach (var allowedUri in redirectUris)
|
||||||
if (client.OauthConfig.RedirectUris.Contains(redirectUri))
|
|
||||||
return true;
|
|
||||||
|
|
||||||
// Check for wildcard matches (e.g., https://*.example.com/*)
|
|
||||||
foreach (var allowedUri in client.OauthConfig.RedirectUris)
|
|
||||||
{
|
{
|
||||||
if (string.IsNullOrEmpty(allowedUri))
|
if (IsWildcardRedirectUriMatch(allowedUri, redirectUri))
|
||||||
continue;
|
|
||||||
|
|
||||||
// Handle wildcard in domain
|
|
||||||
if (allowedUri.Contains("*.") && allowedUri.StartsWith("http"))
|
|
||||||
{
|
|
||||||
try
|
|
||||||
{
|
|
||||||
var allowedUriObj = new Uri(allowedUri);
|
|
||||||
var redirectUriObj = new Uri(redirectUri);
|
|
||||||
|
|
||||||
if (allowedUriObj.Scheme != redirectUriObj.Scheme ||
|
|
||||||
allowedUriObj.Port != redirectUriObj.Port)
|
|
||||||
{
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the domain matches the wildcard pattern
|
|
||||||
var allowedDomain = allowedUriObj.Host;
|
|
||||||
var redirectDomain = redirectUriObj.Host;
|
|
||||||
|
|
||||||
if (allowedDomain.StartsWith("*."))
|
|
||||||
{
|
|
||||||
var baseDomain = allowedDomain[2..]; // Remove the "*." prefix
|
|
||||||
if (redirectDomain == baseDomain || redirectDomain.EndsWith($".{baseDomain}"))
|
|
||||||
{
|
|
||||||
// Check path
|
|
||||||
var allowedPath = allowedUriObj.AbsolutePath.TrimEnd('/');
|
|
||||||
var redirectPath = redirectUriObj.AbsolutePath.TrimEnd('/');
|
|
||||||
|
|
||||||
if (string.IsNullOrEmpty(allowedPath) ||
|
|
||||||
redirectPath.StartsWith(allowedPath, StringComparison.OrdinalIgnoreCase))
|
|
||||||
{
|
{
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
catch (UriFormatException)
|
|
||||||
{
|
|
||||||
// Invalid URI format in allowed URIs, skip
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -219,6 +256,53 @@ public class OidcProviderService(
|
|||||||
return tokenHandler.WriteToken(token);
|
return tokenHandler.WriteToken(token);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private async Task<(SnAuthSession session, string? nonce, List<string>? scopes)> HandleAuthorizationCodeFlowAsync(
|
||||||
|
string authorizationCode,
|
||||||
|
Guid clientId,
|
||||||
|
string? redirectUri,
|
||||||
|
string? codeVerifier
|
||||||
|
)
|
||||||
|
{
|
||||||
|
var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId, redirectUri, codeVerifier);
|
||||||
|
if (authCode == null)
|
||||||
|
throw new InvalidOperationException("Invalid authorization code");
|
||||||
|
|
||||||
|
// Load the session for the user
|
||||||
|
var existingSession = await FindValidSessionAsync(authCode.AccountId, clientId, withAccount: true);
|
||||||
|
|
||||||
|
SnAuthSession session;
|
||||||
|
if (existingSession == null)
|
||||||
|
{
|
||||||
|
var account = await db.Accounts
|
||||||
|
.Where(a => a.Id == authCode.AccountId)
|
||||||
|
.Include(a => a.Profile)
|
||||||
|
.Include(a => a.Contacts)
|
||||||
|
.FirstOrDefaultAsync();
|
||||||
|
if (account == null) throw new InvalidOperationException("Account not found");
|
||||||
|
session = await auth.CreateSessionForOidcAsync(account, SystemClock.Instance.GetCurrentInstant(), clientId);
|
||||||
|
session.Account = account;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
session = existingSession;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (session, authCode.Nonce, authCode.Scopes);
|
||||||
|
}
|
||||||
|
|
||||||
|
private async Task<(SnAuthSession session, string? nonce, List<string>? scopes)> HandleRefreshTokenFlowAsync(Guid sessionId)
|
||||||
|
{
|
||||||
|
var session = await FindSessionByIdAsync(sessionId) ??
|
||||||
|
throw new InvalidOperationException("Session not found");
|
||||||
|
|
||||||
|
// Verify the session is still valid
|
||||||
|
var now = SystemClock.Instance.GetCurrentInstant();
|
||||||
|
if (session.ExpiredAt.HasValue && session.ExpiredAt < now)
|
||||||
|
throw new InvalidOperationException("Session has expired");
|
||||||
|
|
||||||
|
return (session, null, null);
|
||||||
|
}
|
||||||
|
|
||||||
public async Task<TokenResponse> GenerateTokenResponseAsync(
|
public async Task<TokenResponse> GenerateTokenResponseAsync(
|
||||||
Guid clientId,
|
Guid clientId,
|
||||||
string? authorizationCode = null,
|
string? authorizationCode = null,
|
||||||
@@ -227,58 +311,18 @@ public class OidcProviderService(
|
|||||||
Guid? sessionId = null
|
Guid? sessionId = null
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
|
if (clientId == Guid.Empty) throw new ArgumentException("Client ID cannot be empty", nameof(clientId));
|
||||||
|
|
||||||
var client = await FindClientByIdAsync(clientId) ?? throw new InvalidOperationException("Client not found");
|
var client = await FindClientByIdAsync(clientId) ?? throw new InvalidOperationException("Client not found");
|
||||||
|
|
||||||
SnAuthSession session;
|
var (session, nonce, scopes) = authorizationCode != null
|
||||||
|
? await HandleAuthorizationCodeFlowAsync(authorizationCode, clientId, redirectUri, codeVerifier)
|
||||||
|
: sessionId.HasValue
|
||||||
|
? await HandleRefreshTokenFlowAsync(sessionId.Value)
|
||||||
|
: throw new InvalidOperationException("Either authorization code or session ID must be provided");
|
||||||
|
|
||||||
var clock = SystemClock.Instance;
|
var clock = SystemClock.Instance;
|
||||||
var now = clock.GetCurrentInstant();
|
var now = clock.GetCurrentInstant();
|
||||||
string? nonce = null;
|
|
||||||
List<string>? scopes = null;
|
|
||||||
|
|
||||||
if (authorizationCode != null)
|
|
||||||
{
|
|
||||||
// Authorization code flow
|
|
||||||
var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId, redirectUri, codeVerifier);
|
|
||||||
if (authCode == null)
|
|
||||||
throw new InvalidOperationException("Invalid authorization code");
|
|
||||||
|
|
||||||
// Load the session for the user
|
|
||||||
var existingSession = await FindValidSessionAsync(authCode.AccountId, clientId, withAccount: true);
|
|
||||||
|
|
||||||
if (existingSession is null)
|
|
||||||
{
|
|
||||||
var account = await db.Accounts
|
|
||||||
.Where(a => a.Id == authCode.AccountId)
|
|
||||||
.Include(a => a.Profile)
|
|
||||||
.Include(a => a.Contacts)
|
|
||||||
.FirstOrDefaultAsync();
|
|
||||||
if (account is null) throw new InvalidOperationException("Account not found");
|
|
||||||
session = await auth.CreateSessionForOidcAsync(account, clock.GetCurrentInstant(), clientId);
|
|
||||||
session.Account = account;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
session = existingSession;
|
|
||||||
}
|
|
||||||
|
|
||||||
scopes = authCode.Scopes;
|
|
||||||
nonce = authCode.Nonce;
|
|
||||||
}
|
|
||||||
else if (sessionId.HasValue)
|
|
||||||
{
|
|
||||||
// Refresh token flow
|
|
||||||
session = await FindSessionByIdAsync(sessionId.Value) ??
|
|
||||||
throw new InvalidOperationException("Session not found");
|
|
||||||
|
|
||||||
// Verify the session is still valid
|
|
||||||
if (session.ExpiredAt < now)
|
|
||||||
throw new InvalidOperationException("Session has expired");
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
throw new InvalidOperationException("Either authorization code or session ID must be provided");
|
|
||||||
}
|
|
||||||
|
|
||||||
var expiresIn = (int)_options.AccessTokenLifetime.TotalSeconds;
|
var expiresIn = (int)_options.AccessTokenLifetime.TotalSeconds;
|
||||||
var expiresAt = now.Plus(Duration.FromSeconds(expiresIn));
|
var expiresAt = now.Plus(Duration.FromSeconds(expiresIn));
|
||||||
|
|
||||||
@@ -415,7 +459,7 @@ public class OidcProviderService(
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Store the code with its metadata in the cache
|
// Store the code with its metadata in the cache
|
||||||
var cacheKey = $"auth:oidc-code:{code}";
|
var cacheKey = $"{CacheKeyPrefixAuthCode}{code}";
|
||||||
await cache.SetAsync(cacheKey, authCodeInfo, _options.AuthorizationCodeLifetime);
|
await cache.SetAsync(cacheKey, authCodeInfo, _options.AuthorizationCodeLifetime);
|
||||||
|
|
||||||
logger.LogInformation("Generated authorization code for client {ClientId} and user {UserId}", clientId, userId);
|
logger.LogInformation("Generated authorization code for client {ClientId} and user {UserId}", clientId, userId);
|
||||||
@@ -429,7 +473,7 @@ public class OidcProviderService(
|
|||||||
string? codeVerifier = null
|
string? codeVerifier = null
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
var cacheKey = $"auth:oidc-code:{code}";
|
var cacheKey = $"{CacheKeyPrefixAuthCode}{code}";
|
||||||
var (found, authCode) = await cache.GetAsyncWithStatus<AuthorizationCodeInfo>(cacheKey);
|
var (found, authCode) = await cache.GetAsyncWithStatus<AuthorizationCodeInfo>(cacheKey);
|
||||||
|
|
||||||
if (!found || authCode == null)
|
if (!found || authCode == null)
|
||||||
@@ -465,8 +509,8 @@ public class OidcProviderService(
|
|||||||
|
|
||||||
var isValid = authCode.CodeChallengeMethod?.ToUpperInvariant() switch
|
var isValid = authCode.CodeChallengeMethod?.ToUpperInvariant() switch
|
||||||
{
|
{
|
||||||
"S256" => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, "S256"),
|
CodeChallengeMethodS256 => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, CodeChallengeMethodS256),
|
||||||
"PLAIN" => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, "PLAIN"),
|
CodeChallengeMethodPlain => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, CodeChallengeMethodPlain),
|
||||||
_ => false // Unsupported code challenge method
|
_ => false // Unsupported code challenge method
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -504,19 +548,12 @@ public class OidcProviderService(
|
|||||||
{
|
{
|
||||||
if (string.IsNullOrEmpty(codeVerifier)) return false;
|
if (string.IsNullOrEmpty(codeVerifier)) return false;
|
||||||
|
|
||||||
if (method == "S256")
|
if (method != CodeChallengeMethodS256)
|
||||||
{
|
return method == CodeChallengeMethodPlain &&
|
||||||
using var sha256 = SHA256.Create();
|
string.Equals(codeVerifier, codeChallenge, StringComparison.Ordinal);
|
||||||
var hash = sha256.ComputeHash(Encoding.UTF8.GetBytes(codeVerifier));
|
var hash = SHA256.HashData(Encoding.UTF8.GetBytes(codeVerifier));
|
||||||
var base64 = Base64UrlEncoder.Encode(hash);
|
var base64 = Base64UrlEncoder.Encode(hash);
|
||||||
|
|
||||||
return string.Equals(base64, codeChallenge, StringComparison.Ordinal);
|
return string.Equals(base64, codeChallenge, StringComparison.Ordinal);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (method == "PLAIN")
|
|
||||||
{
|
|
||||||
return string.Equals(codeVerifier, codeChallenge, StringComparison.Ordinal);
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user