Compare commits
5 Commits
5445df3b61
...
c74ab20236
| Author | SHA1 | Date | |
|---|---|---|---|
|
c74ab20236
|
|||
|
b9edf51f05
|
|||
|
74a9ca98ad
|
|||
|
4bd59f107b
|
|||
|
08f924f647
|
@@ -27,19 +27,47 @@ public class OidcProviderService(
|
||||
{
|
||||
private readonly OidcProviderOptions _options = options.Value;
|
||||
|
||||
private const string CacheKeyPrefixClientId = "auth:oidc-client:id:";
|
||||
private const string CacheKeyPrefixClientSlug = "auth:oidc-client:slug:";
|
||||
private const string CacheKeyPrefixAuthCode = "auth:oidc-code:";
|
||||
private const string CodeChallengeMethodS256 = "S256";
|
||||
private const string CodeChallengeMethodPlain = "PLAIN";
|
||||
|
||||
public async Task<CustomApp?> FindClientByIdAsync(Guid clientId)
|
||||
{
|
||||
var cacheKey = $"{CacheKeyPrefixClientId}{clientId}";
|
||||
var (found, cachedApp) = await cache.GetAsyncWithStatus<CustomApp>(cacheKey);
|
||||
if (found && cachedApp != null)
|
||||
{
|
||||
return cachedApp;
|
||||
}
|
||||
|
||||
var resp = await customApps.GetCustomAppAsync(new GetCustomAppRequest { Id = clientId.ToString() });
|
||||
if (resp.App != null)
|
||||
{
|
||||
await cache.SetAsync(cacheKey, resp.App, TimeSpan.FromMinutes(5));
|
||||
}
|
||||
return resp.App ?? null;
|
||||
}
|
||||
|
||||
public async Task<CustomApp?> FindClientBySlugAsync(string slug)
|
||||
{
|
||||
var cacheKey = $"{CacheKeyPrefixClientSlug}{slug}";
|
||||
var (found, cachedApp) = await cache.GetAsyncWithStatus<CustomApp>(cacheKey);
|
||||
if (found && cachedApp != null)
|
||||
{
|
||||
return cachedApp;
|
||||
}
|
||||
|
||||
var resp = await customApps.GetCustomAppAsync(new GetCustomAppRequest { Slug = slug });
|
||||
if (resp.App != null)
|
||||
{
|
||||
await cache.SetAsync(cacheKey, resp.App, TimeSpan.FromMinutes(5));
|
||||
}
|
||||
return resp.App ?? null;
|
||||
}
|
||||
|
||||
public async Task<SnAuthSession?> FindValidSessionAsync(Guid accountId, Guid clientId, bool withAccount = false)
|
||||
private async Task<SnAuthSession?> FindValidSessionAsync(Guid accountId, Guid clientId, bool withAccount = false)
|
||||
{
|
||||
var now = SystemClock.Instance.GetCurrentInstant();
|
||||
|
||||
@@ -74,70 +102,79 @@ public class OidcProviderService(
|
||||
return resp.Valid;
|
||||
}
|
||||
|
||||
private static bool IsWildcardRedirectUriMatch(string allowedUri, string redirectUri)
|
||||
{
|
||||
if (string.IsNullOrEmpty(allowedUri) || string.IsNullOrEmpty(redirectUri))
|
||||
return false;
|
||||
|
||||
// Check if it's an exact match
|
||||
if (string.Equals(allowedUri, redirectUri, StringComparison.Ordinal))
|
||||
return true;
|
||||
|
||||
// Quick check for wildcard patterns
|
||||
if (!allowedUri.Contains('*'))
|
||||
return false;
|
||||
|
||||
// Parse URIs once
|
||||
Uri? allowedUriObj, redirectUriObj;
|
||||
try
|
||||
{
|
||||
allowedUriObj = new Uri(allowedUri);
|
||||
redirectUriObj = new Uri(redirectUri);
|
||||
}
|
||||
catch (UriFormatException)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check scheme and port matches
|
||||
if (allowedUriObj.Scheme != redirectUriObj.Scheme || allowedUriObj.Port != redirectUriObj.Port)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var allowedHost = allowedUriObj.Host;
|
||||
var redirectHost = redirectUriObj.Host;
|
||||
|
||||
// Handle wildcard domain patterns like *.example.com
|
||||
if (allowedHost.StartsWith("*."))
|
||||
{
|
||||
var baseDomain = allowedHost[2..]; // Remove "*."
|
||||
if (redirectHost == baseDomain || redirectHost.EndsWith("." + baseDomain))
|
||||
{
|
||||
// Check path match
|
||||
var allowedPath = allowedUriObj.AbsolutePath.TrimEnd('/');
|
||||
var redirectPath = redirectUriObj.AbsolutePath.TrimEnd('/');
|
||||
|
||||
// If allowed path is empty, any path is allowed
|
||||
// If allowed path is specified, redirect path must start with it
|
||||
return string.IsNullOrEmpty(allowedPath) ||
|
||||
redirectPath.StartsWith(allowedPath, StringComparison.OrdinalIgnoreCase);
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
public async Task<bool> ValidateRedirectUriAsync(Guid clientId, string redirectUri)
|
||||
{
|
||||
if (string.IsNullOrEmpty(redirectUri))
|
||||
return false;
|
||||
|
||||
|
||||
var client = await FindClientByIdAsync(clientId);
|
||||
if (client?.Status != Shared.Proto.CustomAppStatus.Production)
|
||||
return true;
|
||||
|
||||
if (client?.OauthConfig?.RedirectUris == null)
|
||||
var redirectUris = client?.OauthConfig?.RedirectUris;
|
||||
if (redirectUris == null || redirectUris.Count == 0)
|
||||
return false;
|
||||
|
||||
// Check if the redirect URI matches any of the allowed URIs
|
||||
// For exact match
|
||||
if (client.OauthConfig.RedirectUris.Contains(redirectUri))
|
||||
return true;
|
||||
|
||||
// Check for wildcard matches (e.g., https://*.example.com/*)
|
||||
foreach (var allowedUri in client.OauthConfig.RedirectUris)
|
||||
// Check each allowed URI for a match
|
||||
foreach (var allowedUri in redirectUris)
|
||||
{
|
||||
if (string.IsNullOrEmpty(allowedUri))
|
||||
continue;
|
||||
|
||||
// Handle wildcard in domain
|
||||
if (allowedUri.Contains("*.") && allowedUri.StartsWith("http"))
|
||||
if (IsWildcardRedirectUriMatch(allowedUri, redirectUri))
|
||||
{
|
||||
try
|
||||
{
|
||||
var allowedUriObj = new Uri(allowedUri);
|
||||
var redirectUriObj = new Uri(redirectUri);
|
||||
|
||||
if (allowedUriObj.Scheme != redirectUriObj.Scheme ||
|
||||
allowedUriObj.Port != redirectUriObj.Port)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if the domain matches the wildcard pattern
|
||||
var allowedDomain = allowedUriObj.Host;
|
||||
var redirectDomain = redirectUriObj.Host;
|
||||
|
||||
if (allowedDomain.StartsWith("*."))
|
||||
{
|
||||
var baseDomain = allowedDomain[2..]; // Remove the "*." prefix
|
||||
if (redirectDomain == baseDomain || redirectDomain.EndsWith($".{baseDomain}"))
|
||||
{
|
||||
// Check path
|
||||
var allowedPath = allowedUriObj.AbsolutePath.TrimEnd('/');
|
||||
var redirectPath = redirectUriObj.AbsolutePath.TrimEnd('/');
|
||||
|
||||
if (string.IsNullOrEmpty(allowedPath) ||
|
||||
redirectPath.StartsWith(allowedPath, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (UriFormatException)
|
||||
{
|
||||
// Invalid URI format in allowed URIs, skip
|
||||
continue;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -219,6 +256,53 @@ public class OidcProviderService(
|
||||
return tokenHandler.WriteToken(token);
|
||||
}
|
||||
|
||||
private async Task<(SnAuthSession session, string? nonce, List<string>? scopes)> HandleAuthorizationCodeFlowAsync(
|
||||
string authorizationCode,
|
||||
Guid clientId,
|
||||
string? redirectUri,
|
||||
string? codeVerifier
|
||||
)
|
||||
{
|
||||
var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId, redirectUri, codeVerifier);
|
||||
if (authCode == null)
|
||||
throw new InvalidOperationException("Invalid authorization code");
|
||||
|
||||
// Load the session for the user
|
||||
var existingSession = await FindValidSessionAsync(authCode.AccountId, clientId, withAccount: true);
|
||||
|
||||
SnAuthSession session;
|
||||
if (existingSession == null)
|
||||
{
|
||||
var account = await db.Accounts
|
||||
.Where(a => a.Id == authCode.AccountId)
|
||||
.Include(a => a.Profile)
|
||||
.Include(a => a.Contacts)
|
||||
.FirstOrDefaultAsync();
|
||||
if (account == null) throw new InvalidOperationException("Account not found");
|
||||
session = await auth.CreateSessionForOidcAsync(account, SystemClock.Instance.GetCurrentInstant(), clientId);
|
||||
session.Account = account;
|
||||
}
|
||||
else
|
||||
{
|
||||
session = existingSession;
|
||||
}
|
||||
|
||||
return (session, authCode.Nonce, authCode.Scopes);
|
||||
}
|
||||
|
||||
private async Task<(SnAuthSession session, string? nonce, List<string>? scopes)> HandleRefreshTokenFlowAsync(Guid sessionId)
|
||||
{
|
||||
var session = await FindSessionByIdAsync(sessionId) ??
|
||||
throw new InvalidOperationException("Session not found");
|
||||
|
||||
// Verify the session is still valid
|
||||
var now = SystemClock.Instance.GetCurrentInstant();
|
||||
if (session.ExpiredAt.HasValue && session.ExpiredAt < now)
|
||||
throw new InvalidOperationException("Session has expired");
|
||||
|
||||
return (session, null, null);
|
||||
}
|
||||
|
||||
public async Task<TokenResponse> GenerateTokenResponseAsync(
|
||||
Guid clientId,
|
||||
string? authorizationCode = null,
|
||||
@@ -227,58 +311,18 @@ public class OidcProviderService(
|
||||
Guid? sessionId = null
|
||||
)
|
||||
{
|
||||
if (clientId == Guid.Empty) throw new ArgumentException("Client ID cannot be empty", nameof(clientId));
|
||||
|
||||
var client = await FindClientByIdAsync(clientId) ?? throw new InvalidOperationException("Client not found");
|
||||
|
||||
SnAuthSession session;
|
||||
var (session, nonce, scopes) = authorizationCode != null
|
||||
? await HandleAuthorizationCodeFlowAsync(authorizationCode, clientId, redirectUri, codeVerifier)
|
||||
: sessionId.HasValue
|
||||
? await HandleRefreshTokenFlowAsync(sessionId.Value)
|
||||
: throw new InvalidOperationException("Either authorization code or session ID must be provided");
|
||||
|
||||
var clock = SystemClock.Instance;
|
||||
var now = clock.GetCurrentInstant();
|
||||
string? nonce = null;
|
||||
List<string>? scopes = null;
|
||||
|
||||
if (authorizationCode != null)
|
||||
{
|
||||
// Authorization code flow
|
||||
var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId, redirectUri, codeVerifier);
|
||||
if (authCode == null)
|
||||
throw new InvalidOperationException("Invalid authorization code");
|
||||
|
||||
// Load the session for the user
|
||||
var existingSession = await FindValidSessionAsync(authCode.AccountId, clientId, withAccount: true);
|
||||
|
||||
if (existingSession is null)
|
||||
{
|
||||
var account = await db.Accounts
|
||||
.Where(a => a.Id == authCode.AccountId)
|
||||
.Include(a => a.Profile)
|
||||
.Include(a => a.Contacts)
|
||||
.FirstOrDefaultAsync();
|
||||
if (account is null) throw new InvalidOperationException("Account not found");
|
||||
session = await auth.CreateSessionForOidcAsync(account, clock.GetCurrentInstant(), clientId);
|
||||
session.Account = account;
|
||||
}
|
||||
else
|
||||
{
|
||||
session = existingSession;
|
||||
}
|
||||
|
||||
scopes = authCode.Scopes;
|
||||
nonce = authCode.Nonce;
|
||||
}
|
||||
else if (sessionId.HasValue)
|
||||
{
|
||||
// Refresh token flow
|
||||
session = await FindSessionByIdAsync(sessionId.Value) ??
|
||||
throw new InvalidOperationException("Session not found");
|
||||
|
||||
// Verify the session is still valid
|
||||
if (session.ExpiredAt < now)
|
||||
throw new InvalidOperationException("Session has expired");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("Either authorization code or session ID must be provided");
|
||||
}
|
||||
|
||||
var expiresIn = (int)_options.AccessTokenLifetime.TotalSeconds;
|
||||
var expiresAt = now.Plus(Duration.FromSeconds(expiresIn));
|
||||
|
||||
@@ -415,7 +459,7 @@ public class OidcProviderService(
|
||||
};
|
||||
|
||||
// Store the code with its metadata in the cache
|
||||
var cacheKey = $"auth:oidc-code:{code}";
|
||||
var cacheKey = $"{CacheKeyPrefixAuthCode}{code}";
|
||||
await cache.SetAsync(cacheKey, authCodeInfo, _options.AuthorizationCodeLifetime);
|
||||
|
||||
logger.LogInformation("Generated authorization code for client {ClientId} and user {UserId}", clientId, userId);
|
||||
@@ -429,7 +473,7 @@ public class OidcProviderService(
|
||||
string? codeVerifier = null
|
||||
)
|
||||
{
|
||||
var cacheKey = $"auth:oidc-code:{code}";
|
||||
var cacheKey = $"{CacheKeyPrefixAuthCode}{code}";
|
||||
var (found, authCode) = await cache.GetAsyncWithStatus<AuthorizationCodeInfo>(cacheKey);
|
||||
|
||||
if (!found || authCode == null)
|
||||
@@ -465,8 +509,8 @@ public class OidcProviderService(
|
||||
|
||||
var isValid = authCode.CodeChallengeMethod?.ToUpperInvariant() switch
|
||||
{
|
||||
"S256" => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, "S256"),
|
||||
"PLAIN" => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, "PLAIN"),
|
||||
CodeChallengeMethodS256 => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, CodeChallengeMethodS256),
|
||||
CodeChallengeMethodPlain => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, CodeChallengeMethodPlain),
|
||||
_ => false // Unsupported code challenge method
|
||||
};
|
||||
|
||||
@@ -504,19 +548,12 @@ public class OidcProviderService(
|
||||
{
|
||||
if (string.IsNullOrEmpty(codeVerifier)) return false;
|
||||
|
||||
if (method == "S256")
|
||||
{
|
||||
using var sha256 = SHA256.Create();
|
||||
var hash = sha256.ComputeHash(Encoding.UTF8.GetBytes(codeVerifier));
|
||||
var base64 = Base64UrlEncoder.Encode(hash);
|
||||
return string.Equals(base64, codeChallenge, StringComparison.Ordinal);
|
||||
}
|
||||
if (method != CodeChallengeMethodS256)
|
||||
return method == CodeChallengeMethodPlain &&
|
||||
string.Equals(codeVerifier, codeChallenge, StringComparison.Ordinal);
|
||||
var hash = SHA256.HashData(Encoding.UTF8.GetBytes(codeVerifier));
|
||||
var base64 = Base64UrlEncoder.Encode(hash);
|
||||
|
||||
if (method == "PLAIN")
|
||||
{
|
||||
return string.Equals(codeVerifier, codeChallenge, StringComparison.Ordinal);
|
||||
}
|
||||
|
||||
return false;
|
||||
return string.Equals(base64, codeChallenge, StringComparison.Ordinal);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,11 @@ public class AfdianOidcService(
|
||||
protected override string DiscoveryEndpoint => ""; // Afdian doesn't have a standard OIDC discovery endpoint
|
||||
protected override string ConfigSectionName => "Afdian";
|
||||
|
||||
public override Task<string> GetAuthorizationUrlAsync(string state, string nonce)
|
||||
{
|
||||
return Task.FromResult(GetAuthorizationUrl(state, nonce));
|
||||
}
|
||||
|
||||
public override string GetAuthorizationUrl(string state, string nonce)
|
||||
{
|
||||
var config = GetProviderConfig();
|
||||
@@ -90,4 +95,4 @@ public class AfdianOidcService(
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,30 @@ public class AppleOidcService(
|
||||
protected override string DiscoveryEndpoint => "https://appleid.apple.com/.well-known/openid-configuration";
|
||||
protected override string ConfigSectionName => "Apple";
|
||||
|
||||
public override async Task<string> GetAuthorizationUrlAsync(string state, string nonce)
|
||||
{
|
||||
var config = GetProviderConfig();
|
||||
var discoveryDocument = await GetDiscoveryDocumentAsync();
|
||||
|
||||
if (discoveryDocument?.AuthorizationEndpoint == null)
|
||||
{
|
||||
throw new InvalidOperationException("Authorization endpoint not found in discovery document");
|
||||
}
|
||||
|
||||
var queryParams = BuildAuthorizationParameters(
|
||||
config.ClientId,
|
||||
config.RedirectUri,
|
||||
"name email",
|
||||
"code id_token",
|
||||
state,
|
||||
nonce,
|
||||
"form_post"
|
||||
);
|
||||
|
||||
var queryString = string.Join("&", queryParams.Select(p => $"{p.Key}={Uri.EscapeDataString(p.Value)}"));
|
||||
return $"{discoveryDocument.AuthorizationEndpoint}?{queryString}";
|
||||
}
|
||||
|
||||
public override string GetAuthorizationUrl(string state, string nonce)
|
||||
{
|
||||
var config = GetProviderConfig();
|
||||
@@ -276,4 +300,4 @@ public class AppleKey
|
||||
|
||||
return Convert.FromBase64String(output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,11 @@ public class DiscordOidcService(
|
||||
protected override string DiscoveryEndpoint => ""; // Discord doesn't have a standard OIDC discovery endpoint
|
||||
protected override string ConfigSectionName => "Discord";
|
||||
|
||||
public override Task<string> GetAuthorizationUrlAsync(string state, string nonce)
|
||||
{
|
||||
return Task.FromResult(GetAuthorizationUrl(state, nonce));
|
||||
}
|
||||
|
||||
public override string GetAuthorizationUrl(string state, string nonce)
|
||||
{
|
||||
var config = GetProviderConfig();
|
||||
@@ -111,4 +116,4 @@ public class DiscordOidcService(
|
||||
Provider = ProviderName
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,11 @@ public class GitHubOidcService(
|
||||
protected override string DiscoveryEndpoint => ""; // GitHub doesn't have a standard OIDC discovery endpoint
|
||||
protected override string ConfigSectionName => "GitHub";
|
||||
|
||||
public override Task<string> GetAuthorizationUrlAsync(string state, string nonce)
|
||||
{
|
||||
return Task.FromResult(GetAuthorizationUrl(state, nonce));
|
||||
}
|
||||
|
||||
public override string GetAuthorizationUrl(string state, string nonce)
|
||||
{
|
||||
var config = GetProviderConfig();
|
||||
@@ -123,4 +128,4 @@ public class GitHubOidcService(
|
||||
public bool Primary { get; set; }
|
||||
public bool Verified { get; set; }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,25 +19,36 @@ public class GoogleOidcService(
|
||||
protected override string DiscoveryEndpoint => "https://accounts.google.com/.well-known/openid-configuration";
|
||||
protected override string ConfigSectionName => "Google";
|
||||
|
||||
public override string GetAuthorizationUrl(string state, string nonce)
|
||||
public override async Task<string> GetAuthorizationUrlAsync(string state, string nonce)
|
||||
{
|
||||
var config = GetProviderConfig();
|
||||
var discoveryDocument = GetDiscoveryDocumentAsync().GetAwaiter().GetResult();
|
||||
var discoveryDocument = await GetDiscoveryDocumentAsync();
|
||||
|
||||
if (discoveryDocument?.AuthorizationEndpoint == null)
|
||||
{
|
||||
throw new InvalidOperationException("Authorization endpoint not found in discovery document");
|
||||
}
|
||||
|
||||
var queryParams = new Dictionary<string, string>
|
||||
{
|
||||
{ "client_id", config.ClientId },
|
||||
{ "redirect_uri", config.RedirectUri },
|
||||
{ "response_type", "code" },
|
||||
{ "scope", "openid email profile" },
|
||||
{ "state", state }, // No '|codeVerifier' appended anymore
|
||||
{ "nonce", nonce }
|
||||
};
|
||||
// Generate PKCE code verifier and challenge for enhanced security
|
||||
var codeVerifier = GenerateCodeVerifier();
|
||||
var codeChallenge = GenerateCodeChallenge(codeVerifier);
|
||||
|
||||
var queryParams = BuildAuthorizationParameters(
|
||||
config.ClientId,
|
||||
config.RedirectUri,
|
||||
"openid email profile",
|
||||
"code",
|
||||
state,
|
||||
nonce
|
||||
);
|
||||
|
||||
// Add PKCE parameters
|
||||
queryParams["code_challenge"] = codeChallenge;
|
||||
queryParams["code_challenge_method"] = "S256";
|
||||
|
||||
// Store code verifier in cache for later token exchange
|
||||
var codeVerifierKey = $"pkce:{state}";
|
||||
await cache.SetAsync(codeVerifierKey, codeVerifier, TimeSpan.FromMinutes(15));
|
||||
|
||||
var queryString = string.Join("&", queryParams.Select(p => $"{p.Key}={Uri.EscapeDataString(p.Value)}"));
|
||||
return $"{discoveryDocument.AuthorizationEndpoint}?{queryString}";
|
||||
@@ -45,89 +56,34 @@ public class GoogleOidcService(
|
||||
|
||||
public override async Task<OidcUserInfo> ProcessCallbackAsync(OidcCallbackData callbackData)
|
||||
{
|
||||
// No need to split or parse code verifier from state
|
||||
var state = callbackData.State ?? "";
|
||||
callbackData.State = state; // Keep the original state if needed
|
||||
|
||||
// Exchange the code for tokens
|
||||
// Pass null or omit the parameter for codeVerifier as PKCE is removed
|
||||
var tokenResponse = await ExchangeCodeForTokensAsync(callbackData.Code, null);
|
||||
if (tokenResponse?.IdToken == null)
|
||||
// Retrieve PKCE code verifier from cache
|
||||
var codeVerifierKey = $"pkce:{state}";
|
||||
var (found, codeVerifier) = await cache.GetAsyncWithStatus<string>(codeVerifierKey);
|
||||
if (!found || string.IsNullOrEmpty(codeVerifier))
|
||||
{
|
||||
throw new InvalidOperationException("Failed to obtain ID token from Google");
|
||||
throw new InvalidOperationException("PKCE code verifier not found or expired");
|
||||
}
|
||||
|
||||
// Validate the ID token
|
||||
var userInfo = await ValidateTokenAsync(tokenResponse.IdToken);
|
||||
// Remove the code verifier from cache to prevent replay attacks
|
||||
await cache.RemoveAsync(codeVerifierKey);
|
||||
|
||||
// Set tokens on the user info
|
||||
userInfo.AccessToken = tokenResponse.AccessToken;
|
||||
userInfo.RefreshToken = tokenResponse.RefreshToken;
|
||||
|
||||
// Try to fetch additional profile data if userinfo endpoint is available
|
||||
try
|
||||
// Exchange the code for tokens using PKCE
|
||||
var tokenResponse = await ExchangeCodeForTokensAsync(callbackData.Code, codeVerifier);
|
||||
if (tokenResponse == null)
|
||||
{
|
||||
var discoveryDocument = await GetDiscoveryDocumentAsync();
|
||||
if (discoveryDocument?.UserinfoEndpoint != null && !string.IsNullOrEmpty(tokenResponse.AccessToken))
|
||||
{
|
||||
var client = _httpClientFactory.CreateClient();
|
||||
client.DefaultRequestHeaders.Authorization =
|
||||
new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", tokenResponse.AccessToken);
|
||||
|
||||
var userInfoResponse =
|
||||
await client.GetFromJsonAsync<Dictionary<string, object>>(discoveryDocument.UserinfoEndpoint);
|
||||
|
||||
if (userInfoResponse != null)
|
||||
{
|
||||
if (userInfoResponse.TryGetValue("picture", out var picture) && picture != null)
|
||||
{
|
||||
userInfo.ProfilePictureUrl = picture.ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Ignore errors when fetching additional profile data
|
||||
throw new InvalidOperationException("Failed to exchange code for tokens");
|
||||
}
|
||||
|
||||
return userInfo;
|
||||
}
|
||||
|
||||
private async Task<OidcUserInfo> ValidateTokenAsync(string idToken)
|
||||
{
|
||||
// Use the strategy pattern to retrieve user info
|
||||
var discoveryDocument = await GetDiscoveryDocumentAsync();
|
||||
if (discoveryDocument?.JwksUri == null)
|
||||
{
|
||||
throw new InvalidOperationException("JWKS URI not found in discovery document");
|
||||
}
|
||||
var config = GetProviderConfig();
|
||||
var strategy = new IdTokenValidationStrategy(_httpClientFactory);
|
||||
|
||||
var client = _httpClientFactory.CreateClient();
|
||||
var jwksResponse = await client.GetFromJsonAsync<JsonWebKeySet>(discoveryDocument.JwksUri);
|
||||
if (jwksResponse == null)
|
||||
{
|
||||
throw new InvalidOperationException("Failed to retrieve JWKS from Google");
|
||||
}
|
||||
|
||||
var handler = new JwtSecurityTokenHandler();
|
||||
var jwtToken = handler.ReadJwtToken(idToken);
|
||||
var kid = jwtToken.Header.Kid;
|
||||
var signingKey = jwksResponse.Keys.FirstOrDefault(k => k.Kid == kid);
|
||||
if (signingKey == null)
|
||||
{
|
||||
throw new SecurityTokenValidationException("Unable to find matching key in Google's JWKS");
|
||||
}
|
||||
|
||||
var validationParameters = new TokenValidationParameters
|
||||
{
|
||||
ValidateIssuer = true,
|
||||
ValidIssuer = "https://accounts.google.com",
|
||||
ValidateAudience = true,
|
||||
ValidAudience = GetProviderConfig().ClientId,
|
||||
ValidateLifetime = true,
|
||||
IssuerSigningKey = signingKey
|
||||
};
|
||||
|
||||
return ValidateAndExtractIdToken(idToken, validationParameters);
|
||||
return await strategy.GetUserInfoAsync(tokenResponse, discoveryDocument, config.ClientId, ProviderName);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -20,6 +20,27 @@ public class MicrosoftOidcService(
|
||||
|
||||
protected override string ConfigSectionName => "Microsoft";
|
||||
|
||||
public override async Task<string> GetAuthorizationUrlAsync(string state, string nonce)
|
||||
{
|
||||
var config = GetProviderConfig();
|
||||
var discoveryDocument = await GetDiscoveryDocumentAsync();
|
||||
|
||||
if (discoveryDocument?.AuthorizationEndpoint == null)
|
||||
throw new InvalidOperationException("Authorization endpoint not found in discovery document.");
|
||||
|
||||
var queryParams = BuildAuthorizationParameters(
|
||||
config.ClientId,
|
||||
config.RedirectUri,
|
||||
"openid profile email",
|
||||
"code",
|
||||
state,
|
||||
nonce
|
||||
);
|
||||
|
||||
var queryString = string.Join("&", queryParams.Select(p => $"{p.Key}={Uri.EscapeDataString(p.Value)}"));
|
||||
return $"{discoveryDocument.AuthorizationEndpoint}?{queryString}";
|
||||
}
|
||||
|
||||
public override string GetAuthorizationUrl(string state, string nonce)
|
||||
{
|
||||
var config = GetProviderConfig();
|
||||
@@ -120,4 +141,4 @@ public class MicrosoftOidcService(
|
||||
Provider = ProviderName
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ public class OidcController(
|
||||
await cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration);
|
||||
|
||||
// The state parameter sent to the provider is the GUID key for the cache.
|
||||
var authUrl = oidcService.GetAuthorizationUrl(state, nonce);
|
||||
var authUrl = await oidcService.GetAuthorizationUrlAsync(state, nonce);
|
||||
return Redirect(authUrl);
|
||||
}
|
||||
else // Otherwise, proceed with the login / registration flow
|
||||
@@ -54,7 +54,7 @@ public class OidcController(
|
||||
// Create login state with return URL and device ID
|
||||
var oidcState = OidcState.ForLogin(returnUrl ?? "/", deviceId);
|
||||
await cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration);
|
||||
var authUrl = oidcService.GetAuthorizationUrl(state, nonce);
|
||||
var authUrl = await oidcService.GetAuthorizationUrlAsync(state, nonce);
|
||||
return Redirect(authUrl);
|
||||
}
|
||||
}
|
||||
@@ -194,4 +194,4 @@ public class OidcController(
|
||||
|
||||
return newAccount;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
using System;
|
||||
using System.IdentityModel.Tokens.Jwt;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using System.Text.Json.Serialization;
|
||||
using DysonNetwork.Shared.Cache;
|
||||
using DysonNetwork.Shared.Models;
|
||||
@@ -39,9 +42,40 @@ public abstract class OidcService(
|
||||
protected abstract string ConfigSectionName { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the authorization URL for initiating the authentication flow
|
||||
/// Gets the authorization URL for initiating the authentication flow (async)
|
||||
/// </summary>
|
||||
public abstract string GetAuthorizationUrl(string state, string nonce);
|
||||
public abstract Task<string> GetAuthorizationUrlAsync(string state, string nonce);
|
||||
|
||||
/// <summary>
|
||||
/// Gets the authorization URL for initiating the authentication flow (sync for backward compatibility)
|
||||
/// </summary>
|
||||
public virtual string GetAuthorizationUrl(string state, string nonce)
|
||||
{
|
||||
return GetAuthorizationUrlAsync(state, nonce).GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Builds common authorization URL query parameters
|
||||
/// </summary>
|
||||
protected Dictionary<string, string> BuildAuthorizationParameters(string clientId, string redirectUri, string scope, string responseType, string state, string nonce, string? responseMode = null)
|
||||
{
|
||||
var parameters = new Dictionary<string, string>
|
||||
{
|
||||
["client_id"] = clientId,
|
||||
["redirect_uri"] = redirectUri,
|
||||
["response_type"] = responseType,
|
||||
["scope"] = scope,
|
||||
["state"] = state,
|
||||
["nonce"] = nonce
|
||||
};
|
||||
|
||||
if (!string.IsNullOrEmpty(responseMode))
|
||||
{
|
||||
parameters["response_mode"] = responseMode;
|
||||
}
|
||||
|
||||
return parameters;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Process the callback from the OIDC provider
|
||||
@@ -61,6 +95,38 @@ public abstract class OidcService(
|
||||
};
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Generates a cryptographically secure random code verifier for PKCE
|
||||
/// </summary>
|
||||
protected static string GenerateCodeVerifier()
|
||||
{
|
||||
// Generate a 32-byte (256-bit) random byte array
|
||||
var randomBytes = new byte[32];
|
||||
RandomNumberGenerator.Fill(randomBytes);
|
||||
|
||||
// Convert to URL-safe base64 (no padding)
|
||||
return Convert.ToBase64String(randomBytes)
|
||||
.Replace("+", "-")
|
||||
.Replace("/", "_")
|
||||
.TrimEnd('=');
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Generates the code challenge from a code verifier using S256 method
|
||||
/// </summary>
|
||||
protected static string GenerateCodeChallenge(string codeVerifier)
|
||||
{
|
||||
using var sha256 = SHA256.Create();
|
||||
var bytes = Encoding.UTF8.GetBytes(codeVerifier);
|
||||
var hash = sha256.ComputeHash(bytes);
|
||||
|
||||
// Convert to URL-safe base64 (no padding)
|
||||
return Convert.ToBase64String(hash)
|
||||
.Replace("+", "-")
|
||||
.Replace("/", "_")
|
||||
.TrimEnd('=');
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Retrieves the OpenID Connect discovery document
|
||||
/// </summary>
|
||||
|
||||
236
DysonNetwork.Pass/Auth/OpenId/UserInfoStrategies.cs
Normal file
236
DysonNetwork.Pass/Auth/OpenId/UserInfoStrategies.cs
Normal file
@@ -0,0 +1,236 @@
|
||||
using System.IdentityModel.Tokens.Jwt;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text.Json;
|
||||
using Microsoft.EntityFrameworkCore;
|
||||
using Microsoft.IdentityModel.Tokens;
|
||||
|
||||
namespace DysonNetwork.Pass.Auth.OpenId;
|
||||
|
||||
/// <summary>
|
||||
/// Defines how to retrieve user information from an OIDC provider
|
||||
/// </summary>
|
||||
public interface IUserInfoStrategy
|
||||
{
|
||||
/// <summary>
|
||||
/// Retrieves user information using the provided token response and discovery document
|
||||
/// </summary>
|
||||
Task<OidcUserInfo> GetUserInfoAsync(OidcTokenResponse tokenResponse, OidcDiscoveryDocument? discoveryDocument,
|
||||
string clientId, string providerName);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Strategy for validating and extracting user info from ID tokens (Google, Apple)
|
||||
/// </summary>
|
||||
public class IdTokenValidationStrategy : IUserInfoStrategy
|
||||
{
|
||||
private readonly IHttpClientFactory _httpClientFactory;
|
||||
|
||||
public IdTokenValidationStrategy(IHttpClientFactory httpClientFactory)
|
||||
{
|
||||
_httpClientFactory = httpClientFactory;
|
||||
}
|
||||
|
||||
public async Task<OidcUserInfo> GetUserInfoAsync(OidcTokenResponse tokenResponse, OidcDiscoveryDocument? discoveryDocument,
|
||||
string clientId, string providerName)
|
||||
{
|
||||
if (string.IsNullOrEmpty(tokenResponse.IdToken))
|
||||
throw new InvalidOperationException("ID token not found in response");
|
||||
|
||||
// Determine issuer and validation parameters based on provider
|
||||
var (issuer, jwksUri) = providerName.ToLower() switch
|
||||
{
|
||||
"google" => ("https://accounts.google.com",
|
||||
discoveryDocument?.JwksUri ?? "https://www.googleapis.com/oauth2/v3/certs"),
|
||||
"apple" => ("https://appleid.apple.com",
|
||||
"https://appleid.apple.com/auth/keys"),
|
||||
_ => throw new NotSupportedException($"ID token validation not supported for provider: {providerName}")
|
||||
};
|
||||
|
||||
// Get and validate the token
|
||||
var jwksJson = await GetJwksAsync(jwksUri);
|
||||
var userInfo = await ValidateIdTokenAsync(tokenResponse.IdToken, clientId, issuer, jwksJson, providerName);
|
||||
|
||||
// Set tokens on the user info
|
||||
userInfo.AccessToken = tokenResponse.AccessToken;
|
||||
userInfo.RefreshToken = tokenResponse.RefreshToken;
|
||||
|
||||
// For Google, try to fetch additional profile data
|
||||
if (providerName.ToLower() == "google" && discoveryDocument?.UserinfoEndpoint != null
|
||||
&& !string.IsNullOrEmpty(tokenResponse.AccessToken))
|
||||
{
|
||||
await FetchAdditionalProfileDataAsync(userInfo, discoveryDocument.UserinfoEndpoint,
|
||||
tokenResponse.AccessToken);
|
||||
}
|
||||
|
||||
// For Apple, parse additional user data if provided
|
||||
if (providerName.ToLower() == "apple")
|
||||
{
|
||||
// Apple-specific handling would go here
|
||||
}
|
||||
|
||||
return userInfo;
|
||||
}
|
||||
|
||||
private async Task<string> GetJwksAsync(string jwksUri)
|
||||
{
|
||||
var client = _httpClientFactory.CreateClient();
|
||||
var response = await client.GetAsync(jwksUri);
|
||||
response.EnsureSuccessStatusCode();
|
||||
return await response.Content.ReadAsStringAsync();
|
||||
}
|
||||
|
||||
private async Task<OidcUserInfo> ValidateIdTokenAsync(string idToken, string clientId, string issuer,
|
||||
string jwksJson, string providerName)
|
||||
{
|
||||
var jwks = JsonSerializer.Deserialize<JsonWebKeySet>(jwksJson)
|
||||
?? throw new InvalidOperationException("Failed to parse JWKS");
|
||||
|
||||
var handler = new JwtSecurityTokenHandler();
|
||||
var jwtToken = handler.ReadJwtToken(idToken);
|
||||
var kid = jwtToken.Header.Kid;
|
||||
|
||||
var signingKey = jwks.Keys.FirstOrDefault(k => k.Kid == kid)
|
||||
?? throw new SecurityTokenValidationException($"Unable to find key {kid} in JWKS");
|
||||
|
||||
var validationParameters = new TokenValidationParameters
|
||||
{
|
||||
ValidateIssuer = true,
|
||||
ValidIssuer = issuer,
|
||||
ValidateAudience = true,
|
||||
ValidAudience = clientId,
|
||||
ValidateLifetime = true,
|
||||
IssuerSigningKey = signingKey
|
||||
};
|
||||
|
||||
handler.ValidateToken(idToken, validationParameters, out _);
|
||||
return ExtractUserInfoFromJwt(jwtToken, providerName);
|
||||
}
|
||||
|
||||
private OidcUserInfo ExtractUserInfoFromJwt(JwtSecurityToken jwtToken, string providerName)
|
||||
{
|
||||
var userId = jwtToken.Claims.FirstOrDefault(c => c.Type == "sub")?.Value;
|
||||
var email = jwtToken.Claims.FirstOrDefault(c => c.Type == "email")?.Value;
|
||||
var emailVerified = jwtToken.Claims.FirstOrDefault(c => c.Type == "email_verified")?.Value == "true";
|
||||
var name = jwtToken.Claims.FirstOrDefault(c => c.Type == "name")?.Value;
|
||||
var givenName = jwtToken.Claims.FirstOrDefault(c => c.Type == "given_name")?.Value;
|
||||
var familyName = jwtToken.Claims.FirstOrDefault(c => c.Type == "family_name")?.Value;
|
||||
var preferredUsername = jwtToken.Claims.FirstOrDefault(c => c.Type == "preferred_username")?.Value;
|
||||
var picture = jwtToken.Claims.FirstOrDefault(c => c.Type == "picture")?.Value;
|
||||
|
||||
// Determine preferred username - try different options
|
||||
var username = preferredUsername;
|
||||
if (string.IsNullOrEmpty(username))
|
||||
{
|
||||
// Fall back to email local part if no preferred username
|
||||
username = !string.IsNullOrEmpty(email) ? email.Split('@')[0] : null;
|
||||
}
|
||||
|
||||
return new OidcUserInfo
|
||||
{
|
||||
UserId = userId,
|
||||
Email = email,
|
||||
EmailVerified = emailVerified,
|
||||
FirstName = givenName ?? "",
|
||||
LastName = familyName ?? "",
|
||||
DisplayName = name ?? $"{givenName} {familyName}".Trim(),
|
||||
PreferredUsername = username ?? "",
|
||||
ProfilePictureUrl = picture,
|
||||
Provider = providerName
|
||||
};
|
||||
}
|
||||
|
||||
private async Task FetchAdditionalProfileDataAsync(OidcUserInfo userInfo, string userinfoEndpoint, string accessToken)
|
||||
{
|
||||
try
|
||||
{
|
||||
var client = _httpClientFactory.CreateClient();
|
||||
client.DefaultRequestHeaders.Authorization =
|
||||
new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", accessToken);
|
||||
|
||||
var userInfoResponse = await client.GetFromJsonAsync<Dictionary<string, object>>(userinfoEndpoint);
|
||||
if (userInfoResponse != null)
|
||||
{
|
||||
if (userInfoResponse.TryGetValue("picture", out var picture) && picture != null)
|
||||
{
|
||||
userInfo.ProfilePictureUrl = picture.ToString();
|
||||
}
|
||||
}
|
||||
}
|
||||
catch
|
||||
{
|
||||
// Ignore errors when fetching additional profile data
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Strategy for fetching user info from OAuth 2.0 userinfo endpoints (Microsoft, Discord, GitHub)
|
||||
/// </summary>
|
||||
public class UserInfoEndpointStrategy : IUserInfoStrategy
|
||||
{
|
||||
private readonly IHttpClientFactory _httpClientFactory;
|
||||
private readonly Func<JsonElement, OidcUserInfo> _parseUserInfo;
|
||||
private readonly string? _userAgent;
|
||||
|
||||
public UserInfoEndpointStrategy(IHttpClientFactory httpClientFactory,
|
||||
Func<JsonElement, OidcUserInfo> parseUserInfo, string? userAgent = null)
|
||||
{
|
||||
_httpClientFactory = httpClientFactory;
|
||||
_parseUserInfo = parseUserInfo;
|
||||
_userAgent = userAgent;
|
||||
}
|
||||
|
||||
public async Task<OidcUserInfo> GetUserInfoAsync(OidcTokenResponse tokenResponse, OidcDiscoveryDocument? discoveryDocument,
|
||||
string clientId, string providerName)
|
||||
{
|
||||
if (string.IsNullOrEmpty(tokenResponse.AccessToken) || string.IsNullOrEmpty(discoveryDocument?.UserinfoEndpoint))
|
||||
throw new InvalidOperationException("Access token or userinfo endpoint missing");
|
||||
|
||||
var client = _httpClientFactory.CreateClient();
|
||||
var request = new HttpRequestMessage(HttpMethod.Get, discoveryDocument.UserinfoEndpoint);
|
||||
request.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", tokenResponse.AccessToken);
|
||||
if (!string.IsNullOrEmpty(_userAgent))
|
||||
request.Headers.Add("User-Agent", _userAgent);
|
||||
|
||||
var response = await client.SendAsync(request);
|
||||
response.EnsureSuccessStatusCode();
|
||||
|
||||
var json = await response.Content.ReadAsStringAsync();
|
||||
var userElement = JsonDocument.Parse(json).RootElement;
|
||||
|
||||
var userInfo = _parseUserInfo(userElement);
|
||||
userInfo.AccessToken = tokenResponse.AccessToken;
|
||||
userInfo.RefreshToken = tokenResponse.RefreshToken;
|
||||
userInfo.Provider = providerName;
|
||||
|
||||
return userInfo;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Strategy for extracting user info directly from token responses (Afdian)
|
||||
/// </summary>
|
||||
public class DirectTokenResponseStrategy : IUserInfoStrategy
|
||||
{
|
||||
public Task<OidcUserInfo> GetUserInfoAsync(OidcTokenResponse tokenResponse, OidcDiscoveryDocument? discoveryDocument,
|
||||
string clientId, string providerName)
|
||||
{
|
||||
// Parse user info directly from token response data
|
||||
// This would depend on how the specific provider returns user data
|
||||
if (string.IsNullOrEmpty(tokenResponse.AccessToken))
|
||||
throw new InvalidOperationException("Access token missing");
|
||||
|
||||
// For Afdian, the user data is embedded in the initial token response
|
||||
// This strategy would need to know how to parse that specific format
|
||||
|
||||
var userInfo = new OidcUserInfo
|
||||
{
|
||||
AccessToken = tokenResponse.AccessToken,
|
||||
RefreshToken = tokenResponse.RefreshToken,
|
||||
Provider = providerName,
|
||||
// Parse user data from token response content...
|
||||
};
|
||||
|
||||
return Task.FromResult(userInfo);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user