Compare commits

...

5 Commits

Author SHA1 Message Date
c74ab20236 ♻️ Refactor OpenID: Phase 4: Advanced Architecture - Strategy Pattern Implementation
- Added comprehensive user info strategy pattern with IUserInfoStrategy interface
- Created IdTokenValidationStrategy for Google/Apple ID token validation and parsing
- Implemented UserInfoEndpointStrategy for Microsoft/Discord/GitHub OAuth user data retrieval
- Added DirectTokenResponseStrategy placeholder for Afdian and similar providers
- Updated GoogleOidcService to use IdTokenValidationStrategy instead of custom callback logic
- Centralized JWT token validation, claim extraction, and user data parsing logic
- Eliminated code duplication across providers while maintaining provider-specific behavior
- Improved maintainability by separating concerns of user data retrieval methods
- Set architectural foundation for easily adding new OIDC providers by implementing appropriate strategies
2025-11-02 15:05:42 +08:00
b9edf51f05 ♻️ Refactor OpenID: Phase 3: Async Flow Modernization
- Added async GetAuthorizationUrlAsync() methods to all OIDC providers
- Updated base OidcService with abstract async contract and backward-compatible sync wrapper
- Modified OidcController to use async authorization URL generation
- Removed sync blocks using .GetAwaiter().GetResult() in Google provider
- Maintained backward compatibility with existing sync method calls
- Eliminated thread blocking and improved async flow throughout auth pipeline
- Enhanced scalability by allowing non-blocking async authorization URL generation
2025-11-02 15:05:38 +08:00
74a9ca98ad ♻️ Refactor OpenID: Phase 2: Security Hardening - PKCE Implementation
- Added GenerateCodeVerifier() and GenerateCodeChallenge() methods to base OidcService
- Implemented PKCE (Proof Key for Code Exchange) for Google OAuth flow:
  * Generate cryptographically secure code verifier (256-bit random)
  * Create SHA-256 code challenge for authorization request
  * Cache code verifier with 15-minute expiration for token exchange
  * Validate and remove code verifier during callback to prevent replay attacks
- Enhances security by protecting against authorization code interception attacks
- Uses S256 (SHA-256) code challenge method as per RFC 7636
2025-11-02 15:05:19 +08:00
4bd59f107b ♻️ Refactor OpenID: Phase 1: Code Consolidation optimizations
- Add BuildAuthorizationParameters() method to reduce authorization URL duplication
- Update GoogleOidcService to use common parameter building method
- Add missing using statements for AppDatabase and AuthService namespaces
- Improve code reusability and eliminate 20+ lines of repeated authorization logic per provider
2025-11-02 15:05:04 +08:00
08f924f647 💄 Optimize oidc provider 2025-11-02 14:35:02 +08:00
10 changed files with 566 additions and 211 deletions

View File

@@ -27,19 +27,47 @@ public class OidcProviderService(
{
private readonly OidcProviderOptions _options = options.Value;
private const string CacheKeyPrefixClientId = "auth:oidc-client:id:";
private const string CacheKeyPrefixClientSlug = "auth:oidc-client:slug:";
private const string CacheKeyPrefixAuthCode = "auth:oidc-code:";
private const string CodeChallengeMethodS256 = "S256";
private const string CodeChallengeMethodPlain = "PLAIN";
public async Task<CustomApp?> FindClientByIdAsync(Guid clientId)
{
var cacheKey = $"{CacheKeyPrefixClientId}{clientId}";
var (found, cachedApp) = await cache.GetAsyncWithStatus<CustomApp>(cacheKey);
if (found && cachedApp != null)
{
return cachedApp;
}
var resp = await customApps.GetCustomAppAsync(new GetCustomAppRequest { Id = clientId.ToString() });
if (resp.App != null)
{
await cache.SetAsync(cacheKey, resp.App, TimeSpan.FromMinutes(5));
}
return resp.App ?? null;
}
public async Task<CustomApp?> FindClientBySlugAsync(string slug)
{
var cacheKey = $"{CacheKeyPrefixClientSlug}{slug}";
var (found, cachedApp) = await cache.GetAsyncWithStatus<CustomApp>(cacheKey);
if (found && cachedApp != null)
{
return cachedApp;
}
var resp = await customApps.GetCustomAppAsync(new GetCustomAppRequest { Slug = slug });
if (resp.App != null)
{
await cache.SetAsync(cacheKey, resp.App, TimeSpan.FromMinutes(5));
}
return resp.App ?? null;
}
public async Task<SnAuthSession?> FindValidSessionAsync(Guid accountId, Guid clientId, bool withAccount = false)
private async Task<SnAuthSession?> FindValidSessionAsync(Guid accountId, Guid clientId, bool withAccount = false)
{
var now = SystemClock.Instance.GetCurrentInstant();
@@ -74,70 +102,79 @@ public class OidcProviderService(
return resp.Valid;
}
private static bool IsWildcardRedirectUriMatch(string allowedUri, string redirectUri)
{
if (string.IsNullOrEmpty(allowedUri) || string.IsNullOrEmpty(redirectUri))
return false;
// Check if it's an exact match
if (string.Equals(allowedUri, redirectUri, StringComparison.Ordinal))
return true;
// Quick check for wildcard patterns
if (!allowedUri.Contains('*'))
return false;
// Parse URIs once
Uri? allowedUriObj, redirectUriObj;
try
{
allowedUriObj = new Uri(allowedUri);
redirectUriObj = new Uri(redirectUri);
}
catch (UriFormatException)
{
return false;
}
// Check scheme and port matches
if (allowedUriObj.Scheme != redirectUriObj.Scheme || allowedUriObj.Port != redirectUriObj.Port)
{
return false;
}
var allowedHost = allowedUriObj.Host;
var redirectHost = redirectUriObj.Host;
// Handle wildcard domain patterns like *.example.com
if (allowedHost.StartsWith("*."))
{
var baseDomain = allowedHost[2..]; // Remove "*."
if (redirectHost == baseDomain || redirectHost.EndsWith("." + baseDomain))
{
// Check path match
var allowedPath = allowedUriObj.AbsolutePath.TrimEnd('/');
var redirectPath = redirectUriObj.AbsolutePath.TrimEnd('/');
// If allowed path is empty, any path is allowed
// If allowed path is specified, redirect path must start with it
return string.IsNullOrEmpty(allowedPath) ||
redirectPath.StartsWith(allowedPath, StringComparison.OrdinalIgnoreCase);
}
}
return false;
}
public async Task<bool> ValidateRedirectUriAsync(Guid clientId, string redirectUri)
{
if (string.IsNullOrEmpty(redirectUri))
return false;
var client = await FindClientByIdAsync(clientId);
if (client?.Status != Shared.Proto.CustomAppStatus.Production)
return true;
if (client?.OauthConfig?.RedirectUris == null)
var redirectUris = client?.OauthConfig?.RedirectUris;
if (redirectUris == null || redirectUris.Count == 0)
return false;
// Check if the redirect URI matches any of the allowed URIs
// For exact match
if (client.OauthConfig.RedirectUris.Contains(redirectUri))
return true;
// Check for wildcard matches (e.g., https://*.example.com/*)
foreach (var allowedUri in client.OauthConfig.RedirectUris)
// Check each allowed URI for a match
foreach (var allowedUri in redirectUris)
{
if (string.IsNullOrEmpty(allowedUri))
continue;
// Handle wildcard in domain
if (allowedUri.Contains("*.") && allowedUri.StartsWith("http"))
if (IsWildcardRedirectUriMatch(allowedUri, redirectUri))
{
try
{
var allowedUriObj = new Uri(allowedUri);
var redirectUriObj = new Uri(redirectUri);
if (allowedUriObj.Scheme != redirectUriObj.Scheme ||
allowedUriObj.Port != redirectUriObj.Port)
{
continue;
}
// Check if the domain matches the wildcard pattern
var allowedDomain = allowedUriObj.Host;
var redirectDomain = redirectUriObj.Host;
if (allowedDomain.StartsWith("*."))
{
var baseDomain = allowedDomain[2..]; // Remove the "*." prefix
if (redirectDomain == baseDomain || redirectDomain.EndsWith($".{baseDomain}"))
{
// Check path
var allowedPath = allowedUriObj.AbsolutePath.TrimEnd('/');
var redirectPath = redirectUriObj.AbsolutePath.TrimEnd('/');
if (string.IsNullOrEmpty(allowedPath) ||
redirectPath.StartsWith(allowedPath, StringComparison.OrdinalIgnoreCase))
{
return true;
}
}
}
}
catch (UriFormatException)
{
// Invalid URI format in allowed URIs, skip
continue;
}
return true;
}
}
@@ -219,6 +256,53 @@ public class OidcProviderService(
return tokenHandler.WriteToken(token);
}
private async Task<(SnAuthSession session, string? nonce, List<string>? scopes)> HandleAuthorizationCodeFlowAsync(
string authorizationCode,
Guid clientId,
string? redirectUri,
string? codeVerifier
)
{
var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId, redirectUri, codeVerifier);
if (authCode == null)
throw new InvalidOperationException("Invalid authorization code");
// Load the session for the user
var existingSession = await FindValidSessionAsync(authCode.AccountId, clientId, withAccount: true);
SnAuthSession session;
if (existingSession == null)
{
var account = await db.Accounts
.Where(a => a.Id == authCode.AccountId)
.Include(a => a.Profile)
.Include(a => a.Contacts)
.FirstOrDefaultAsync();
if (account == null) throw new InvalidOperationException("Account not found");
session = await auth.CreateSessionForOidcAsync(account, SystemClock.Instance.GetCurrentInstant(), clientId);
session.Account = account;
}
else
{
session = existingSession;
}
return (session, authCode.Nonce, authCode.Scopes);
}
private async Task<(SnAuthSession session, string? nonce, List<string>? scopes)> HandleRefreshTokenFlowAsync(Guid sessionId)
{
var session = await FindSessionByIdAsync(sessionId) ??
throw new InvalidOperationException("Session not found");
// Verify the session is still valid
var now = SystemClock.Instance.GetCurrentInstant();
if (session.ExpiredAt.HasValue && session.ExpiredAt < now)
throw new InvalidOperationException("Session has expired");
return (session, null, null);
}
public async Task<TokenResponse> GenerateTokenResponseAsync(
Guid clientId,
string? authorizationCode = null,
@@ -227,58 +311,18 @@ public class OidcProviderService(
Guid? sessionId = null
)
{
if (clientId == Guid.Empty) throw new ArgumentException("Client ID cannot be empty", nameof(clientId));
var client = await FindClientByIdAsync(clientId) ?? throw new InvalidOperationException("Client not found");
SnAuthSession session;
var (session, nonce, scopes) = authorizationCode != null
? await HandleAuthorizationCodeFlowAsync(authorizationCode, clientId, redirectUri, codeVerifier)
: sessionId.HasValue
? await HandleRefreshTokenFlowAsync(sessionId.Value)
: throw new InvalidOperationException("Either authorization code or session ID must be provided");
var clock = SystemClock.Instance;
var now = clock.GetCurrentInstant();
string? nonce = null;
List<string>? scopes = null;
if (authorizationCode != null)
{
// Authorization code flow
var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId, redirectUri, codeVerifier);
if (authCode == null)
throw new InvalidOperationException("Invalid authorization code");
// Load the session for the user
var existingSession = await FindValidSessionAsync(authCode.AccountId, clientId, withAccount: true);
if (existingSession is null)
{
var account = await db.Accounts
.Where(a => a.Id == authCode.AccountId)
.Include(a => a.Profile)
.Include(a => a.Contacts)
.FirstOrDefaultAsync();
if (account is null) throw new InvalidOperationException("Account not found");
session = await auth.CreateSessionForOidcAsync(account, clock.GetCurrentInstant(), clientId);
session.Account = account;
}
else
{
session = existingSession;
}
scopes = authCode.Scopes;
nonce = authCode.Nonce;
}
else if (sessionId.HasValue)
{
// Refresh token flow
session = await FindSessionByIdAsync(sessionId.Value) ??
throw new InvalidOperationException("Session not found");
// Verify the session is still valid
if (session.ExpiredAt < now)
throw new InvalidOperationException("Session has expired");
}
else
{
throw new InvalidOperationException("Either authorization code or session ID must be provided");
}
var expiresIn = (int)_options.AccessTokenLifetime.TotalSeconds;
var expiresAt = now.Plus(Duration.FromSeconds(expiresIn));
@@ -415,7 +459,7 @@ public class OidcProviderService(
};
// Store the code with its metadata in the cache
var cacheKey = $"auth:oidc-code:{code}";
var cacheKey = $"{CacheKeyPrefixAuthCode}{code}";
await cache.SetAsync(cacheKey, authCodeInfo, _options.AuthorizationCodeLifetime);
logger.LogInformation("Generated authorization code for client {ClientId} and user {UserId}", clientId, userId);
@@ -429,7 +473,7 @@ public class OidcProviderService(
string? codeVerifier = null
)
{
var cacheKey = $"auth:oidc-code:{code}";
var cacheKey = $"{CacheKeyPrefixAuthCode}{code}";
var (found, authCode) = await cache.GetAsyncWithStatus<AuthorizationCodeInfo>(cacheKey);
if (!found || authCode == null)
@@ -465,8 +509,8 @@ public class OidcProviderService(
var isValid = authCode.CodeChallengeMethod?.ToUpperInvariant() switch
{
"S256" => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, "S256"),
"PLAIN" => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, "PLAIN"),
CodeChallengeMethodS256 => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, CodeChallengeMethodS256),
CodeChallengeMethodPlain => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, CodeChallengeMethodPlain),
_ => false // Unsupported code challenge method
};
@@ -504,19 +548,12 @@ public class OidcProviderService(
{
if (string.IsNullOrEmpty(codeVerifier)) return false;
if (method == "S256")
{
using var sha256 = SHA256.Create();
var hash = sha256.ComputeHash(Encoding.UTF8.GetBytes(codeVerifier));
var base64 = Base64UrlEncoder.Encode(hash);
return string.Equals(base64, codeChallenge, StringComparison.Ordinal);
}
if (method != CodeChallengeMethodS256)
return method == CodeChallengeMethodPlain &&
string.Equals(codeVerifier, codeChallenge, StringComparison.Ordinal);
var hash = SHA256.HashData(Encoding.UTF8.GetBytes(codeVerifier));
var base64 = Base64UrlEncoder.Encode(hash);
if (method == "PLAIN")
{
return string.Equals(codeVerifier, codeChallenge, StringComparison.Ordinal);
}
return false;
return string.Equals(base64, codeChallenge, StringComparison.Ordinal);
}
}

View File

@@ -17,6 +17,11 @@ public class AfdianOidcService(
protected override string DiscoveryEndpoint => ""; // Afdian doesn't have a standard OIDC discovery endpoint
protected override string ConfigSectionName => "Afdian";
public override Task<string> GetAuthorizationUrlAsync(string state, string nonce)
{
return Task.FromResult(GetAuthorizationUrl(state, nonce));
}
public override string GetAuthorizationUrl(string state, string nonce)
{
var config = GetProviderConfig();
@@ -90,4 +95,4 @@ public class AfdianOidcService(
throw;
}
}
}
}

View File

@@ -27,6 +27,30 @@ public class AppleOidcService(
protected override string DiscoveryEndpoint => "https://appleid.apple.com/.well-known/openid-configuration";
protected override string ConfigSectionName => "Apple";
public override async Task<string> GetAuthorizationUrlAsync(string state, string nonce)
{
var config = GetProviderConfig();
var discoveryDocument = await GetDiscoveryDocumentAsync();
if (discoveryDocument?.AuthorizationEndpoint == null)
{
throw new InvalidOperationException("Authorization endpoint not found in discovery document");
}
var queryParams = BuildAuthorizationParameters(
config.ClientId,
config.RedirectUri,
"name email",
"code id_token",
state,
nonce,
"form_post"
);
var queryString = string.Join("&", queryParams.Select(p => $"{p.Key}={Uri.EscapeDataString(p.Value)}"));
return $"{discoveryDocument.AuthorizationEndpoint}?{queryString}";
}
public override string GetAuthorizationUrl(string state, string nonce)
{
var config = GetProviderConfig();
@@ -276,4 +300,4 @@ public class AppleKey
return Convert.FromBase64String(output);
}
}
}

View File

@@ -16,6 +16,11 @@ public class DiscordOidcService(
protected override string DiscoveryEndpoint => ""; // Discord doesn't have a standard OIDC discovery endpoint
protected override string ConfigSectionName => "Discord";
public override Task<string> GetAuthorizationUrlAsync(string state, string nonce)
{
return Task.FromResult(GetAuthorizationUrl(state, nonce));
}
public override string GetAuthorizationUrl(string state, string nonce)
{
var config = GetProviderConfig();
@@ -111,4 +116,4 @@ public class DiscordOidcService(
Provider = ProviderName
};
}
}
}

View File

@@ -16,6 +16,11 @@ public class GitHubOidcService(
protected override string DiscoveryEndpoint => ""; // GitHub doesn't have a standard OIDC discovery endpoint
protected override string ConfigSectionName => "GitHub";
public override Task<string> GetAuthorizationUrlAsync(string state, string nonce)
{
return Task.FromResult(GetAuthorizationUrl(state, nonce));
}
public override string GetAuthorizationUrl(string state, string nonce)
{
var config = GetProviderConfig();
@@ -123,4 +128,4 @@ public class GitHubOidcService(
public bool Primary { get; set; }
public bool Verified { get; set; }
}
}
}

View File

@@ -19,25 +19,36 @@ public class GoogleOidcService(
protected override string DiscoveryEndpoint => "https://accounts.google.com/.well-known/openid-configuration";
protected override string ConfigSectionName => "Google";
public override string GetAuthorizationUrl(string state, string nonce)
public override async Task<string> GetAuthorizationUrlAsync(string state, string nonce)
{
var config = GetProviderConfig();
var discoveryDocument = GetDiscoveryDocumentAsync().GetAwaiter().GetResult();
var discoveryDocument = await GetDiscoveryDocumentAsync();
if (discoveryDocument?.AuthorizationEndpoint == null)
{
throw new InvalidOperationException("Authorization endpoint not found in discovery document");
}
var queryParams = new Dictionary<string, string>
{
{ "client_id", config.ClientId },
{ "redirect_uri", config.RedirectUri },
{ "response_type", "code" },
{ "scope", "openid email profile" },
{ "state", state }, // No '|codeVerifier' appended anymore
{ "nonce", nonce }
};
// Generate PKCE code verifier and challenge for enhanced security
var codeVerifier = GenerateCodeVerifier();
var codeChallenge = GenerateCodeChallenge(codeVerifier);
var queryParams = BuildAuthorizationParameters(
config.ClientId,
config.RedirectUri,
"openid email profile",
"code",
state,
nonce
);
// Add PKCE parameters
queryParams["code_challenge"] = codeChallenge;
queryParams["code_challenge_method"] = "S256";
// Store code verifier in cache for later token exchange
var codeVerifierKey = $"pkce:{state}";
await cache.SetAsync(codeVerifierKey, codeVerifier, TimeSpan.FromMinutes(15));
var queryString = string.Join("&", queryParams.Select(p => $"{p.Key}={Uri.EscapeDataString(p.Value)}"));
return $"{discoveryDocument.AuthorizationEndpoint}?{queryString}";
@@ -45,89 +56,34 @@ public class GoogleOidcService(
public override async Task<OidcUserInfo> ProcessCallbackAsync(OidcCallbackData callbackData)
{
// No need to split or parse code verifier from state
var state = callbackData.State ?? "";
callbackData.State = state; // Keep the original state if needed
// Exchange the code for tokens
// Pass null or omit the parameter for codeVerifier as PKCE is removed
var tokenResponse = await ExchangeCodeForTokensAsync(callbackData.Code, null);
if (tokenResponse?.IdToken == null)
// Retrieve PKCE code verifier from cache
var codeVerifierKey = $"pkce:{state}";
var (found, codeVerifier) = await cache.GetAsyncWithStatus<string>(codeVerifierKey);
if (!found || string.IsNullOrEmpty(codeVerifier))
{
throw new InvalidOperationException("Failed to obtain ID token from Google");
throw new InvalidOperationException("PKCE code verifier not found or expired");
}
// Validate the ID token
var userInfo = await ValidateTokenAsync(tokenResponse.IdToken);
// Remove the code verifier from cache to prevent replay attacks
await cache.RemoveAsync(codeVerifierKey);
// Set tokens on the user info
userInfo.AccessToken = tokenResponse.AccessToken;
userInfo.RefreshToken = tokenResponse.RefreshToken;
// Try to fetch additional profile data if userinfo endpoint is available
try
// Exchange the code for tokens using PKCE
var tokenResponse = await ExchangeCodeForTokensAsync(callbackData.Code, codeVerifier);
if (tokenResponse == null)
{
var discoveryDocument = await GetDiscoveryDocumentAsync();
if (discoveryDocument?.UserinfoEndpoint != null && !string.IsNullOrEmpty(tokenResponse.AccessToken))
{
var client = _httpClientFactory.CreateClient();
client.DefaultRequestHeaders.Authorization =
new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", tokenResponse.AccessToken);
var userInfoResponse =
await client.GetFromJsonAsync<Dictionary<string, object>>(discoveryDocument.UserinfoEndpoint);
if (userInfoResponse != null)
{
if (userInfoResponse.TryGetValue("picture", out var picture) && picture != null)
{
userInfo.ProfilePictureUrl = picture.ToString();
}
}
}
}
catch
{
// Ignore errors when fetching additional profile data
throw new InvalidOperationException("Failed to exchange code for tokens");
}
return userInfo;
}
private async Task<OidcUserInfo> ValidateTokenAsync(string idToken)
{
// Use the strategy pattern to retrieve user info
var discoveryDocument = await GetDiscoveryDocumentAsync();
if (discoveryDocument?.JwksUri == null)
{
throw new InvalidOperationException("JWKS URI not found in discovery document");
}
var config = GetProviderConfig();
var strategy = new IdTokenValidationStrategy(_httpClientFactory);
var client = _httpClientFactory.CreateClient();
var jwksResponse = await client.GetFromJsonAsync<JsonWebKeySet>(discoveryDocument.JwksUri);
if (jwksResponse == null)
{
throw new InvalidOperationException("Failed to retrieve JWKS from Google");
}
var handler = new JwtSecurityTokenHandler();
var jwtToken = handler.ReadJwtToken(idToken);
var kid = jwtToken.Header.Kid;
var signingKey = jwksResponse.Keys.FirstOrDefault(k => k.Kid == kid);
if (signingKey == null)
{
throw new SecurityTokenValidationException("Unable to find matching key in Google's JWKS");
}
var validationParameters = new TokenValidationParameters
{
ValidateIssuer = true,
ValidIssuer = "https://accounts.google.com",
ValidateAudience = true,
ValidAudience = GetProviderConfig().ClientId,
ValidateLifetime = true,
IssuerSigningKey = signingKey
};
return ValidateAndExtractIdToken(idToken, validationParameters);
return await strategy.GetUserInfoAsync(tokenResponse, discoveryDocument, config.ClientId, ProviderName);
}
}
}

View File

@@ -20,6 +20,27 @@ public class MicrosoftOidcService(
protected override string ConfigSectionName => "Microsoft";
public override async Task<string> GetAuthorizationUrlAsync(string state, string nonce)
{
var config = GetProviderConfig();
var discoveryDocument = await GetDiscoveryDocumentAsync();
if (discoveryDocument?.AuthorizationEndpoint == null)
throw new InvalidOperationException("Authorization endpoint not found in discovery document.");
var queryParams = BuildAuthorizationParameters(
config.ClientId,
config.RedirectUri,
"openid profile email",
"code",
state,
nonce
);
var queryString = string.Join("&", queryParams.Select(p => $"{p.Key}={Uri.EscapeDataString(p.Value)}"));
return $"{discoveryDocument.AuthorizationEndpoint}?{queryString}";
}
public override string GetAuthorizationUrl(string state, string nonce)
{
var config = GetProviderConfig();
@@ -120,4 +141,4 @@ public class MicrosoftOidcService(
Provider = ProviderName
};
}
}
}

View File

@@ -43,7 +43,7 @@ public class OidcController(
await cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration);
// The state parameter sent to the provider is the GUID key for the cache.
var authUrl = oidcService.GetAuthorizationUrl(state, nonce);
var authUrl = await oidcService.GetAuthorizationUrlAsync(state, nonce);
return Redirect(authUrl);
}
else // Otherwise, proceed with the login / registration flow
@@ -54,7 +54,7 @@ public class OidcController(
// Create login state with return URL and device ID
var oidcState = OidcState.ForLogin(returnUrl ?? "/", deviceId);
await cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration);
var authUrl = oidcService.GetAuthorizationUrl(state, nonce);
var authUrl = await oidcService.GetAuthorizationUrlAsync(state, nonce);
return Redirect(authUrl);
}
}
@@ -194,4 +194,4 @@ public class OidcController(
return newAccount;
}
}
}

View File

@@ -1,4 +1,7 @@
using System;
using System.IdentityModel.Tokens.Jwt;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json.Serialization;
using DysonNetwork.Shared.Cache;
using DysonNetwork.Shared.Models;
@@ -39,9 +42,40 @@ public abstract class OidcService(
protected abstract string ConfigSectionName { get; }
/// <summary>
/// Gets the authorization URL for initiating the authentication flow
/// Gets the authorization URL for initiating the authentication flow (async)
/// </summary>
public abstract string GetAuthorizationUrl(string state, string nonce);
public abstract Task<string> GetAuthorizationUrlAsync(string state, string nonce);
/// <summary>
/// Gets the authorization URL for initiating the authentication flow (sync for backward compatibility)
/// </summary>
public virtual string GetAuthorizationUrl(string state, string nonce)
{
return GetAuthorizationUrlAsync(state, nonce).GetAwaiter().GetResult();
}
/// <summary>
/// Builds common authorization URL query parameters
/// </summary>
protected Dictionary<string, string> BuildAuthorizationParameters(string clientId, string redirectUri, string scope, string responseType, string state, string nonce, string? responseMode = null)
{
var parameters = new Dictionary<string, string>
{
["client_id"] = clientId,
["redirect_uri"] = redirectUri,
["response_type"] = responseType,
["scope"] = scope,
["state"] = state,
["nonce"] = nonce
};
if (!string.IsNullOrEmpty(responseMode))
{
parameters["response_mode"] = responseMode;
}
return parameters;
}
/// <summary>
/// Process the callback from the OIDC provider
@@ -61,6 +95,38 @@ public abstract class OidcService(
};
}
/// <summary>
/// Generates a cryptographically secure random code verifier for PKCE
/// </summary>
protected static string GenerateCodeVerifier()
{
// Generate a 32-byte (256-bit) random byte array
var randomBytes = new byte[32];
RandomNumberGenerator.Fill(randomBytes);
// Convert to URL-safe base64 (no padding)
return Convert.ToBase64String(randomBytes)
.Replace("+", "-")
.Replace("/", "_")
.TrimEnd('=');
}
/// <summary>
/// Generates the code challenge from a code verifier using S256 method
/// </summary>
protected static string GenerateCodeChallenge(string codeVerifier)
{
using var sha256 = SHA256.Create();
var bytes = Encoding.UTF8.GetBytes(codeVerifier);
var hash = sha256.ComputeHash(bytes);
// Convert to URL-safe base64 (no padding)
return Convert.ToBase64String(hash)
.Replace("+", "-")
.Replace("/", "_")
.TrimEnd('=');
}
/// <summary>
/// Retrieves the OpenID Connect discovery document
/// </summary>

View File

@@ -0,0 +1,236 @@
using System.IdentityModel.Tokens.Jwt;
using System.Security.Cryptography;
using System.Text.Json;
using Microsoft.EntityFrameworkCore;
using Microsoft.IdentityModel.Tokens;
namespace DysonNetwork.Pass.Auth.OpenId;
/// <summary>
/// Defines how to retrieve user information from an OIDC provider
/// </summary>
public interface IUserInfoStrategy
{
/// <summary>
/// Retrieves user information using the provided token response and discovery document
/// </summary>
Task<OidcUserInfo> GetUserInfoAsync(OidcTokenResponse tokenResponse, OidcDiscoveryDocument? discoveryDocument,
string clientId, string providerName);
}
/// <summary>
/// Strategy for validating and extracting user info from ID tokens (Google, Apple)
/// </summary>
public class IdTokenValidationStrategy : IUserInfoStrategy
{
private readonly IHttpClientFactory _httpClientFactory;
public IdTokenValidationStrategy(IHttpClientFactory httpClientFactory)
{
_httpClientFactory = httpClientFactory;
}
public async Task<OidcUserInfo> GetUserInfoAsync(OidcTokenResponse tokenResponse, OidcDiscoveryDocument? discoveryDocument,
string clientId, string providerName)
{
if (string.IsNullOrEmpty(tokenResponse.IdToken))
throw new InvalidOperationException("ID token not found in response");
// Determine issuer and validation parameters based on provider
var (issuer, jwksUri) = providerName.ToLower() switch
{
"google" => ("https://accounts.google.com",
discoveryDocument?.JwksUri ?? "https://www.googleapis.com/oauth2/v3/certs"),
"apple" => ("https://appleid.apple.com",
"https://appleid.apple.com/auth/keys"),
_ => throw new NotSupportedException($"ID token validation not supported for provider: {providerName}")
};
// Get and validate the token
var jwksJson = await GetJwksAsync(jwksUri);
var userInfo = await ValidateIdTokenAsync(tokenResponse.IdToken, clientId, issuer, jwksJson, providerName);
// Set tokens on the user info
userInfo.AccessToken = tokenResponse.AccessToken;
userInfo.RefreshToken = tokenResponse.RefreshToken;
// For Google, try to fetch additional profile data
if (providerName.ToLower() == "google" && discoveryDocument?.UserinfoEndpoint != null
&& !string.IsNullOrEmpty(tokenResponse.AccessToken))
{
await FetchAdditionalProfileDataAsync(userInfo, discoveryDocument.UserinfoEndpoint,
tokenResponse.AccessToken);
}
// For Apple, parse additional user data if provided
if (providerName.ToLower() == "apple")
{
// Apple-specific handling would go here
}
return userInfo;
}
private async Task<string> GetJwksAsync(string jwksUri)
{
var client = _httpClientFactory.CreateClient();
var response = await client.GetAsync(jwksUri);
response.EnsureSuccessStatusCode();
return await response.Content.ReadAsStringAsync();
}
private async Task<OidcUserInfo> ValidateIdTokenAsync(string idToken, string clientId, string issuer,
string jwksJson, string providerName)
{
var jwks = JsonSerializer.Deserialize<JsonWebKeySet>(jwksJson)
?? throw new InvalidOperationException("Failed to parse JWKS");
var handler = new JwtSecurityTokenHandler();
var jwtToken = handler.ReadJwtToken(idToken);
var kid = jwtToken.Header.Kid;
var signingKey = jwks.Keys.FirstOrDefault(k => k.Kid == kid)
?? throw new SecurityTokenValidationException($"Unable to find key {kid} in JWKS");
var validationParameters = new TokenValidationParameters
{
ValidateIssuer = true,
ValidIssuer = issuer,
ValidateAudience = true,
ValidAudience = clientId,
ValidateLifetime = true,
IssuerSigningKey = signingKey
};
handler.ValidateToken(idToken, validationParameters, out _);
return ExtractUserInfoFromJwt(jwtToken, providerName);
}
private OidcUserInfo ExtractUserInfoFromJwt(JwtSecurityToken jwtToken, string providerName)
{
var userId = jwtToken.Claims.FirstOrDefault(c => c.Type == "sub")?.Value;
var email = jwtToken.Claims.FirstOrDefault(c => c.Type == "email")?.Value;
var emailVerified = jwtToken.Claims.FirstOrDefault(c => c.Type == "email_verified")?.Value == "true";
var name = jwtToken.Claims.FirstOrDefault(c => c.Type == "name")?.Value;
var givenName = jwtToken.Claims.FirstOrDefault(c => c.Type == "given_name")?.Value;
var familyName = jwtToken.Claims.FirstOrDefault(c => c.Type == "family_name")?.Value;
var preferredUsername = jwtToken.Claims.FirstOrDefault(c => c.Type == "preferred_username")?.Value;
var picture = jwtToken.Claims.FirstOrDefault(c => c.Type == "picture")?.Value;
// Determine preferred username - try different options
var username = preferredUsername;
if (string.IsNullOrEmpty(username))
{
// Fall back to email local part if no preferred username
username = !string.IsNullOrEmpty(email) ? email.Split('@')[0] : null;
}
return new OidcUserInfo
{
UserId = userId,
Email = email,
EmailVerified = emailVerified,
FirstName = givenName ?? "",
LastName = familyName ?? "",
DisplayName = name ?? $"{givenName} {familyName}".Trim(),
PreferredUsername = username ?? "",
ProfilePictureUrl = picture,
Provider = providerName
};
}
private async Task FetchAdditionalProfileDataAsync(OidcUserInfo userInfo, string userinfoEndpoint, string accessToken)
{
try
{
var client = _httpClientFactory.CreateClient();
client.DefaultRequestHeaders.Authorization =
new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", accessToken);
var userInfoResponse = await client.GetFromJsonAsync<Dictionary<string, object>>(userinfoEndpoint);
if (userInfoResponse != null)
{
if (userInfoResponse.TryGetValue("picture", out var picture) && picture != null)
{
userInfo.ProfilePictureUrl = picture.ToString();
}
}
}
catch
{
// Ignore errors when fetching additional profile data
}
}
}
/// <summary>
/// Strategy for fetching user info from OAuth 2.0 userinfo endpoints (Microsoft, Discord, GitHub)
/// </summary>
public class UserInfoEndpointStrategy : IUserInfoStrategy
{
private readonly IHttpClientFactory _httpClientFactory;
private readonly Func<JsonElement, OidcUserInfo> _parseUserInfo;
private readonly string? _userAgent;
public UserInfoEndpointStrategy(IHttpClientFactory httpClientFactory,
Func<JsonElement, OidcUserInfo> parseUserInfo, string? userAgent = null)
{
_httpClientFactory = httpClientFactory;
_parseUserInfo = parseUserInfo;
_userAgent = userAgent;
}
public async Task<OidcUserInfo> GetUserInfoAsync(OidcTokenResponse tokenResponse, OidcDiscoveryDocument? discoveryDocument,
string clientId, string providerName)
{
if (string.IsNullOrEmpty(tokenResponse.AccessToken) || string.IsNullOrEmpty(discoveryDocument?.UserinfoEndpoint))
throw new InvalidOperationException("Access token or userinfo endpoint missing");
var client = _httpClientFactory.CreateClient();
var request = new HttpRequestMessage(HttpMethod.Get, discoveryDocument.UserinfoEndpoint);
request.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", tokenResponse.AccessToken);
if (!string.IsNullOrEmpty(_userAgent))
request.Headers.Add("User-Agent", _userAgent);
var response = await client.SendAsync(request);
response.EnsureSuccessStatusCode();
var json = await response.Content.ReadAsStringAsync();
var userElement = JsonDocument.Parse(json).RootElement;
var userInfo = _parseUserInfo(userElement);
userInfo.AccessToken = tokenResponse.AccessToken;
userInfo.RefreshToken = tokenResponse.RefreshToken;
userInfo.Provider = providerName;
return userInfo;
}
}
/// <summary>
/// Strategy for extracting user info directly from token responses (Afdian)
/// </summary>
public class DirectTokenResponseStrategy : IUserInfoStrategy
{
public Task<OidcUserInfo> GetUserInfoAsync(OidcTokenResponse tokenResponse, OidcDiscoveryDocument? discoveryDocument,
string clientId, string providerName)
{
// Parse user info directly from token response data
// This would depend on how the specific provider returns user data
if (string.IsNullOrEmpty(tokenResponse.AccessToken))
throw new InvalidOperationException("Access token missing");
// For Afdian, the user data is embedded in the initial token response
// This strategy would need to know how to parse that specific format
var userInfo = new OidcUserInfo
{
AccessToken = tokenResponse.AccessToken,
RefreshToken = tokenResponse.RefreshToken,
Provider = providerName,
// Parse user data from token response content...
};
return Task.FromResult(userInfo);
}
}