⚡ Add cache to oidc discovery
This commit is contained in:
		| @@ -3,6 +3,7 @@ using System.Security.Cryptography; | ||||
| using System.Text; | ||||
| using System.Text.Json; | ||||
| using System.Text.Json.Serialization; | ||||
| using DysonNetwork.Sphere.Storage; | ||||
| using Microsoft.IdentityModel.Tokens; | ||||
|  | ||||
| namespace DysonNetwork.Sphere.Auth.OpenId; | ||||
| @@ -13,9 +14,10 @@ namespace DysonNetwork.Sphere.Auth.OpenId; | ||||
| public class AppleOidcService( | ||||
|     IConfiguration configuration, | ||||
|     IHttpClientFactory httpClientFactory, | ||||
|     AppDatabase db | ||||
|     AppDatabase db, | ||||
|     ICacheService cache | ||||
| ) | ||||
|     : OidcService(configuration, httpClientFactory, db) | ||||
|     : OidcService(configuration, httpClientFactory, db, cache) | ||||
| { | ||||
|     private readonly IConfiguration _configuration = configuration; | ||||
|     private readonly IHttpClientFactory _httpClientFactory = httpClientFactory; | ||||
|   | ||||
| @@ -1,15 +1,17 @@ | ||||
| using System.Net.Http.Json; | ||||
| using System.Text.Json; | ||||
| using DysonNetwork.Sphere.Storage; | ||||
|  | ||||
| namespace DysonNetwork.Sphere.Auth.OpenId; | ||||
|  | ||||
| public class DiscordOidcService : OidcService | ||||
| public class DiscordOidcService( | ||||
|     IConfiguration configuration, | ||||
|     IHttpClientFactory httpClientFactory, | ||||
|     AppDatabase db, | ||||
|     ICacheService cache | ||||
| ) | ||||
|     : OidcService(configuration, httpClientFactory, db, cache) | ||||
| { | ||||
|     public DiscordOidcService(IConfiguration configuration, IHttpClientFactory httpClientFactory, AppDatabase db) | ||||
|         : base(configuration, httpClientFactory, db) | ||||
|     { | ||||
|     } | ||||
|  | ||||
|     public override string ProviderName => "Discord"; | ||||
|     protected override string DiscoveryEndpoint => ""; // Discord doesn't have a standard OIDC discovery endpoint | ||||
|     protected override string ConfigSectionName => "Discord"; | ||||
| @@ -46,10 +48,11 @@ public class DiscordOidcService : OidcService | ||||
|         return userInfo; | ||||
|     } | ||||
|  | ||||
|     protected override async Task<OidcTokenResponse?> ExchangeCodeForTokensAsync(string code, string? codeVerifier = null) | ||||
|     protected override async Task<OidcTokenResponse?> ExchangeCodeForTokensAsync(string code, | ||||
|         string? codeVerifier = null) | ||||
|     { | ||||
|         var config = GetProviderConfig(); | ||||
|         var client = _httpClientFactory.CreateClient(); | ||||
|         var client = HttpClientFactory.CreateClient(); | ||||
|  | ||||
|         var content = new FormUrlEncodedContent(new Dictionary<string, string> | ||||
|         { | ||||
| @@ -68,7 +71,7 @@ public class DiscordOidcService : OidcService | ||||
|  | ||||
|     private async Task<OidcUserInfo> GetUserInfoAsync(string accessToken) | ||||
|     { | ||||
|         var client = _httpClientFactory.CreateClient(); | ||||
|         var client = HttpClientFactory.CreateClient(); | ||||
|         var request = new HttpRequestMessage(HttpMethod.Get, "https://discord.com/api/users/@me"); | ||||
|         request.Headers.Add("Authorization", $"Bearer {accessToken}"); | ||||
|  | ||||
| @@ -85,11 +88,16 @@ public class DiscordOidcService : OidcService | ||||
|         { | ||||
|             UserId = userId, | ||||
|             Email = (discordUser.TryGetProperty("email", out var emailElement) ? emailElement.GetString() : null) ?? "", | ||||
|             EmailVerified = discordUser.TryGetProperty("verified", out var verifiedElement) && verifiedElement.GetBoolean(), | ||||
|             DisplayName = (discordUser.TryGetProperty("global_name", out var globalNameElement) ? globalNameElement.GetString() : null) ?? "", | ||||
|             EmailVerified = discordUser.TryGetProperty("verified", out var verifiedElement) && | ||||
|                             verifiedElement.GetBoolean(), | ||||
|             DisplayName = (discordUser.TryGetProperty("global_name", out var globalNameElement) | ||||
|                 ? globalNameElement.GetString() | ||||
|                 : null) ?? "", | ||||
|             PreferredUsername = discordUser.GetProperty("username").GetString() ?? "", | ||||
|             ProfilePictureUrl = !string.IsNullOrEmpty(avatar) ? $"https://cdn.discordapp.com/avatars/{userId}/{avatar}.png" : "", | ||||
|             ProfilePictureUrl = !string.IsNullOrEmpty(avatar) | ||||
|                 ? $"https://cdn.discordapp.com/avatars/{userId}/{avatar}.png" | ||||
|                 : "", | ||||
|             Provider = ProviderName | ||||
|         }; | ||||
|     } | ||||
| } | ||||
| } | ||||
| @@ -1,15 +1,17 @@ | ||||
| using System.Net.Http.Json; | ||||
| using System.Text.Json; | ||||
| using DysonNetwork.Sphere.Storage; | ||||
|  | ||||
| namespace DysonNetwork.Sphere.Auth.OpenId; | ||||
|  | ||||
| public class GitHubOidcService : OidcService | ||||
| public class GitHubOidcService( | ||||
|     IConfiguration configuration, | ||||
|     IHttpClientFactory httpClientFactory, | ||||
|     AppDatabase db, | ||||
|     ICacheService cache | ||||
| ) | ||||
|     : OidcService(configuration, httpClientFactory, db, cache) | ||||
| { | ||||
|     public GitHubOidcService(IConfiguration configuration, IHttpClientFactory httpClientFactory, AppDatabase db) | ||||
|         : base(configuration, httpClientFactory, db) | ||||
|     { | ||||
|     } | ||||
|  | ||||
|     public override string ProviderName => "GitHub"; | ||||
|     protected override string DiscoveryEndpoint => ""; // GitHub doesn't have a standard OIDC discovery endpoint | ||||
|     protected override string ConfigSectionName => "GitHub"; | ||||
| @@ -45,10 +47,11 @@ public class GitHubOidcService : OidcService | ||||
|         return userInfo; | ||||
|     } | ||||
|  | ||||
|     protected override async Task<OidcTokenResponse?> ExchangeCodeForTokensAsync(string code, string? codeVerifier = null) | ||||
|     protected override async Task<OidcTokenResponse?> ExchangeCodeForTokensAsync(string code, | ||||
|         string? codeVerifier = null) | ||||
|     { | ||||
|         var config = GetProviderConfig(); | ||||
|         var client = _httpClientFactory.CreateClient(); | ||||
|         var client = HttpClientFactory.CreateClient(); | ||||
|  | ||||
|         var tokenRequest = new HttpRequestMessage(HttpMethod.Post, "https://github.com/login/oauth/access_token") | ||||
|         { | ||||
| @@ -70,7 +73,7 @@ public class GitHubOidcService : OidcService | ||||
|  | ||||
|     private async Task<OidcUserInfo> GetUserInfoAsync(string accessToken) | ||||
|     { | ||||
|         var client = _httpClientFactory.CreateClient(); | ||||
|         var client = HttpClientFactory.CreateClient(); | ||||
|         var request = new HttpRequestMessage(HttpMethod.Get, "https://api.github.com/user"); | ||||
|         request.Headers.Add("Authorization", $"Bearer {accessToken}"); | ||||
|         request.Headers.Add("User-Agent", "DysonNetwork.Sphere"); | ||||
| @@ -93,14 +96,16 @@ public class GitHubOidcService : OidcService | ||||
|             Email = email, | ||||
|             DisplayName = githubUser.TryGetProperty("name", out var nameElement) ? nameElement.GetString() ?? "" : "", | ||||
|             PreferredUsername = githubUser.GetProperty("login").GetString() ?? "", | ||||
|             ProfilePictureUrl = githubUser.TryGetProperty("avatar_url", out var avatarElement) ? avatarElement.GetString() ?? "" : "", | ||||
|             ProfilePictureUrl = githubUser.TryGetProperty("avatar_url", out var avatarElement) | ||||
|                 ? avatarElement.GetString() ?? "" | ||||
|                 : "", | ||||
|             Provider = ProviderName | ||||
|         }; | ||||
|     } | ||||
|  | ||||
|     private async Task<string?> GetPrimaryEmailAsync(string accessToken) | ||||
|     { | ||||
|         var client = _httpClientFactory.CreateClient(); | ||||
|         var client = HttpClientFactory.CreateClient(); | ||||
|         var request = new HttpRequestMessage(HttpMethod.Get, "https://api.github.com/user/emails"); | ||||
|         request.Headers.Add("Authorization", $"Bearer {accessToken}"); | ||||
|         request.Headers.Add("User-Agent", "DysonNetwork.Sphere"); | ||||
| @@ -118,4 +123,4 @@ public class GitHubOidcService : OidcService | ||||
|         public bool Primary { get; set; } | ||||
|         public bool Verified { get; set; } | ||||
|     } | ||||
| } | ||||
| } | ||||
| @@ -2,6 +2,7 @@ using System.IdentityModel.Tokens.Jwt; | ||||
| using System.Net.Http.Json; | ||||
| using System.Security.Cryptography; | ||||
| using System.Text; | ||||
| using DysonNetwork.Sphere.Storage; | ||||
| using Microsoft.IdentityModel.Tokens; | ||||
|  | ||||
| namespace DysonNetwork.Sphere.Auth.OpenId; | ||||
| @@ -12,12 +13,13 @@ namespace DysonNetwork.Sphere.Auth.OpenId; | ||||
| public class GoogleOidcService( | ||||
|     IConfiguration configuration, | ||||
|     IHttpClientFactory httpClientFactory, | ||||
|     AppDatabase db | ||||
|     AppDatabase db, | ||||
|     ICacheService cache | ||||
| ) | ||||
|     : OidcService(configuration, httpClientFactory, db) | ||||
|     : OidcService(configuration, httpClientFactory, db, cache) | ||||
| { | ||||
|     private readonly IHttpClientFactory _httpClientFactory = httpClientFactory; | ||||
|      | ||||
|  | ||||
|     public override string ProviderName => "google"; | ||||
|     protected override string DiscoveryEndpoint => "https://accounts.google.com/.well-known/openid-configuration"; | ||||
|     protected override string ConfigSectionName => "Google"; | ||||
| @@ -85,16 +87,17 @@ public class GoogleOidcService( | ||||
|         userInfo.RefreshToken = tokenResponse.RefreshToken; | ||||
|  | ||||
|         // Try to fetch additional profile data if userinfo endpoint is available | ||||
|         try  | ||||
|         try | ||||
|         { | ||||
|             var discoveryDocument = await GetDiscoveryDocumentAsync(); | ||||
|             if (discoveryDocument?.UserinfoEndpoint != null && !string.IsNullOrEmpty(tokenResponse.AccessToken)) | ||||
|             { | ||||
|                 var client = _httpClientFactory.CreateClient(); | ||||
|                 client.DefaultRequestHeaders.Authorization =  | ||||
|                 client.DefaultRequestHeaders.Authorization = | ||||
|                     new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", tokenResponse.AccessToken); | ||||
|  | ||||
|                 var userInfoResponse = await client.GetFromJsonAsync<Dictionary<string, object>>(discoveryDocument.UserinfoEndpoint); | ||||
|                 var userInfoResponse = | ||||
|                     await client.GetFromJsonAsync<Dictionary<string, object>>(discoveryDocument.UserinfoEndpoint); | ||||
|  | ||||
|                 if (userInfoResponse != null) | ||||
|                 { | ||||
|   | ||||
| @@ -1,18 +1,22 @@ | ||||
| using System.Net.Http.Json; | ||||
| using System.Text.Json; | ||||
| using DysonNetwork.Sphere.Storage; | ||||
|  | ||||
| namespace DysonNetwork.Sphere.Auth.OpenId; | ||||
|  | ||||
| public class MicrosoftOidcService : OidcService | ||||
| public class MicrosoftOidcService( | ||||
|     IConfiguration configuration, | ||||
|     IHttpClientFactory httpClientFactory, | ||||
|     AppDatabase db, | ||||
|     ICacheService cache | ||||
| ) | ||||
|     : OidcService(configuration, httpClientFactory, db, cache) | ||||
| { | ||||
|     public MicrosoftOidcService(IConfiguration configuration, IHttpClientFactory httpClientFactory, AppDatabase db) | ||||
|         : base(configuration, httpClientFactory, db) | ||||
|     { | ||||
|     } | ||||
|  | ||||
|     public override string ProviderName => "Microsoft"; | ||||
|  | ||||
|     protected override string DiscoveryEndpoint => _configuration[$"Oidc:{ConfigSectionName}:DiscoveryEndpoint"] ?? throw new InvalidOperationException("Microsoft OIDC discovery endpoint is not configured."); | ||||
|     protected override string DiscoveryEndpoint => Configuration[$"Oidc:{ConfigSectionName}:DiscoveryEndpoint"] ?? | ||||
|                                                    throw new InvalidOperationException( | ||||
|                                                        "Microsoft OIDC discovery endpoint is not configured."); | ||||
|  | ||||
|     protected override string ConfigSectionName => "Microsoft"; | ||||
|  | ||||
| @@ -54,7 +58,8 @@ public class MicrosoftOidcService : OidcService | ||||
|         return userInfo; | ||||
|     } | ||||
|  | ||||
|     protected override async Task<OidcTokenResponse?> ExchangeCodeForTokensAsync(string code, string? codeVerifier = null) | ||||
|     protected override async Task<OidcTokenResponse?> ExchangeCodeForTokensAsync(string code, | ||||
|         string? codeVerifier = null) | ||||
|     { | ||||
|         var config = GetProviderConfig(); | ||||
|         var discoveryDocument = await GetDiscoveryDocumentAsync(); | ||||
| @@ -63,7 +68,7 @@ public class MicrosoftOidcService : OidcService | ||||
|             throw new InvalidOperationException("Token endpoint not found in discovery document."); | ||||
|         } | ||||
|  | ||||
|         var client = _httpClientFactory.CreateClient(); | ||||
|         var client = HttpClientFactory.CreateClient(); | ||||
|  | ||||
|         var tokenRequest = new HttpRequestMessage(HttpMethod.Post, discoveryDocument.TokenEndpoint) | ||||
|         { | ||||
| @@ -90,7 +95,7 @@ public class MicrosoftOidcService : OidcService | ||||
|         if (discoveryDocument?.UserinfoEndpoint == null) | ||||
|             throw new InvalidOperationException("Userinfo endpoint not found in discovery document."); | ||||
|  | ||||
|         var client = _httpClientFactory.CreateClient(); | ||||
|         var client = HttpClientFactory.CreateClient(); | ||||
|         var request = new HttpRequestMessage(HttpMethod.Get, discoveryDocument.UserinfoEndpoint); | ||||
|         request.Headers.Add("Authorization", $"Bearer {accessToken}"); | ||||
|  | ||||
| @@ -104,10 +109,15 @@ public class MicrosoftOidcService : OidcService | ||||
|         { | ||||
|             UserId = microsoftUser.GetProperty("sub").GetString() ?? "", | ||||
|             Email = microsoftUser.TryGetProperty("email", out var emailElement) ? emailElement.GetString() : null, | ||||
|             DisplayName = microsoftUser.TryGetProperty("name", out var nameElement) ? nameElement.GetString() ?? "" : "", | ||||
|             PreferredUsername = microsoftUser.TryGetProperty("preferred_username", out var preferredUsernameElement) ? preferredUsernameElement.GetString() ?? "" : "", | ||||
|             ProfilePictureUrl = microsoftUser.TryGetProperty("picture", out var pictureElement) ? pictureElement.GetString() ?? "" : "", | ||||
|             DisplayName = | ||||
|                 microsoftUser.TryGetProperty("name", out var nameElement) ? nameElement.GetString() ?? "" : "", | ||||
|             PreferredUsername = microsoftUser.TryGetProperty("preferred_username", out var preferredUsernameElement) | ||||
|                 ? preferredUsernameElement.GetString() ?? "" | ||||
|                 : "", | ||||
|             ProfilePictureUrl = microsoftUser.TryGetProperty("picture", out var pictureElement) | ||||
|                 ? pictureElement.GetString() ?? "" | ||||
|                 : "", | ||||
|             Provider = ProviderName | ||||
|         }; | ||||
|     } | ||||
| } | ||||
| } | ||||
| @@ -2,6 +2,7 @@ using System.IdentityModel.Tokens.Jwt; | ||||
| using System.Net.Http.Json; | ||||
| using System.Text.Json.Serialization; | ||||
| using DysonNetwork.Sphere.Account; | ||||
| using DysonNetwork.Sphere.Storage; | ||||
| using Microsoft.EntityFrameworkCore; | ||||
| using Microsoft.IdentityModel.Tokens; | ||||
| using NodaTime; | ||||
| @@ -11,18 +12,16 @@ namespace DysonNetwork.Sphere.Auth.OpenId; | ||||
| /// <summary> | ||||
| /// Base service for OpenID Connect authentication providers | ||||
| /// </summary> | ||||
| public abstract class OidcService | ||||
| public abstract class OidcService( | ||||
|     IConfiguration configuration, | ||||
|     IHttpClientFactory httpClientFactory, | ||||
|     AppDatabase db, | ||||
|     ICacheService cache | ||||
| ) | ||||
| { | ||||
|     protected readonly IConfiguration _configuration; | ||||
|     protected readonly IHttpClientFactory _httpClientFactory; | ||||
|     protected readonly AppDatabase _db; | ||||
|  | ||||
|     protected OidcService(IConfiguration configuration, IHttpClientFactory httpClientFactory, AppDatabase db) | ||||
|     { | ||||
|         _configuration = configuration; | ||||
|         _httpClientFactory = httpClientFactory; | ||||
|         _db = db; | ||||
|     } | ||||
|     protected readonly IConfiguration Configuration = configuration; | ||||
|     protected readonly IHttpClientFactory HttpClientFactory = httpClientFactory; | ||||
|     protected readonly AppDatabase Db = db; | ||||
|  | ||||
|     /// <summary> | ||||
|     /// Gets the unique identifier for this provider | ||||
| @@ -56,9 +55,9 @@ public abstract class OidcService | ||||
|     { | ||||
|         return new ProviderConfiguration | ||||
|         { | ||||
|                         ClientId = _configuration[$"Oidc:{ConfigSectionName}:ClientId"] ?? "", | ||||
|                         ClientSecret = _configuration[$"Oidc:{ConfigSectionName}:ClientSecret"] ?? "", | ||||
|                         RedirectUri = _configuration["BaseUrl"] + "/auth/callback/" + ProviderName.ToLower() | ||||
|             ClientId = Configuration[$"Oidc:{ConfigSectionName}:ClientId"] ?? "", | ||||
|             ClientSecret = Configuration[$"Oidc:{ConfigSectionName}:ClientSecret"] ?? "", | ||||
|             RedirectUri = Configuration["BaseUrl"] + "/auth/callback/" + ProviderName.ToLower() | ||||
|         }; | ||||
|     } | ||||
|  | ||||
| @@ -67,10 +66,28 @@ public abstract class OidcService | ||||
|     /// </summary> | ||||
|     protected async Task<OidcDiscoveryDocument?> GetDiscoveryDocumentAsync() | ||||
|     { | ||||
|         var client = _httpClientFactory.CreateClient(); | ||||
|         // Construct a cache key unique to the current provider: | ||||
|         var cacheKey = $"oidc-discovery:{ProviderName}"; | ||||
|  | ||||
|         // Try getting the discovery document from cache first: | ||||
|         var (found, cachedDoc) = await cache.GetAsyncWithStatus<OidcDiscoveryDocument>(cacheKey); | ||||
|         if (found && cachedDoc != null) | ||||
|         { | ||||
|             return cachedDoc; | ||||
|         } | ||||
|  | ||||
|         // If it's not cached, fetch from the actual discovery endpoint: | ||||
|         var client = HttpClientFactory.CreateClient(); | ||||
|         var response = await client.GetAsync(DiscoveryEndpoint); | ||||
|         response.EnsureSuccessStatusCode(); | ||||
|         return await response.Content.ReadFromJsonAsync<OidcDiscoveryDocument>(); | ||||
|         var doc = await response.Content.ReadFromJsonAsync<OidcDiscoveryDocument>(); | ||||
|  | ||||
|         // Store the discovery document in the cache for a while (e.g., 15 minutes): | ||||
|         if (doc is not null) | ||||
|             await cache.SetAsync(cacheKey, doc, TimeSpan.FromMinutes(15)); | ||||
|  | ||||
|         return doc; | ||||
|  | ||||
|     } | ||||
|  | ||||
|     /// <summary> | ||||
| @@ -87,7 +104,7 @@ public abstract class OidcService | ||||
|             throw new InvalidOperationException("Token endpoint not found in discovery document"); | ||||
|         } | ||||
|  | ||||
|         var client = _httpClientFactory.CreateClient(); | ||||
|         var client = HttpClientFactory.CreateClient(); | ||||
|         var content = new FormUrlEncodedContent(BuildTokenRequestParameters(code, config, codeVerifier)); | ||||
|  | ||||
|         var response = await client.PostAsync(discoveryDocument.TokenEndpoint, content); | ||||
| @@ -178,7 +195,7 @@ public abstract class OidcService | ||||
|     ) | ||||
|     { | ||||
|         // Create or update the account connection | ||||
|                 var connection = await _db.AccountConnections | ||||
|         var connection = await Db.AccountConnections | ||||
|             .FirstOrDefaultAsync(c => c.Provider == ProviderName && | ||||
|                                       c.ProvidedIdentifier == userInfo.UserId && | ||||
|                                       c.AccountId == account.Id | ||||
| @@ -195,7 +212,7 @@ public abstract class OidcService | ||||
|                 LastUsedAt = SystemClock.Instance.GetCurrentInstant(), | ||||
|                 AccountId = account.Id | ||||
|             }; | ||||
|                         await _db.AccountConnections.AddAsync(connection); | ||||
|             await Db.AccountConnections.AddAsync(connection); | ||||
|         } | ||||
|  | ||||
|         // Create a challenge that's already completed | ||||
| @@ -215,7 +232,7 @@ public abstract class OidcService | ||||
|             UserAgent = request.Request.Headers.UserAgent, | ||||
|         }; | ||||
|  | ||||
|                 await _db.AuthChallenges.AddAsync(challenge); | ||||
|         await Db.AuthChallenges.AddAsync(challenge); | ||||
|  | ||||
|         // Create a session | ||||
|         var session = new Session | ||||
| @@ -226,8 +243,8 @@ public abstract class OidcService | ||||
|             Challenge = challenge | ||||
|         }; | ||||
|  | ||||
|                 await _db.AuthSessions.AddAsync(session); | ||||
|                 await _db.SaveChangesAsync(); | ||||
|         await Db.AuthSessions.AddAsync(session); | ||||
|         await Db.SaveChangesAsync(); | ||||
|  | ||||
|         return session; | ||||
|     } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user