Add cache to oidc discovery

This commit is contained in:
LittleSheep 2025-06-16 23:15:45 +08:00
parent 47caff569d
commit fe04b12561
6 changed files with 114 additions and 69 deletions

View File

@ -3,6 +3,7 @@ using System.Security.Cryptography;
using System.Text; using System.Text;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using DysonNetwork.Sphere.Storage;
using Microsoft.IdentityModel.Tokens; using Microsoft.IdentityModel.Tokens;
namespace DysonNetwork.Sphere.Auth.OpenId; namespace DysonNetwork.Sphere.Auth.OpenId;
@ -13,9 +14,10 @@ namespace DysonNetwork.Sphere.Auth.OpenId;
public class AppleOidcService( public class AppleOidcService(
IConfiguration configuration, IConfiguration configuration,
IHttpClientFactory httpClientFactory, IHttpClientFactory httpClientFactory,
AppDatabase db AppDatabase db,
ICacheService cache
) )
: OidcService(configuration, httpClientFactory, db) : OidcService(configuration, httpClientFactory, db, cache)
{ {
private readonly IConfiguration _configuration = configuration; private readonly IConfiguration _configuration = configuration;
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory; private readonly IHttpClientFactory _httpClientFactory = httpClientFactory;

View File

@ -1,15 +1,17 @@
using System.Net.Http.Json; using System.Net.Http.Json;
using System.Text.Json; using System.Text.Json;
using DysonNetwork.Sphere.Storage;
namespace DysonNetwork.Sphere.Auth.OpenId; namespace DysonNetwork.Sphere.Auth.OpenId;
public class DiscordOidcService : OidcService public class DiscordOidcService(
IConfiguration configuration,
IHttpClientFactory httpClientFactory,
AppDatabase db,
ICacheService cache
)
: OidcService(configuration, httpClientFactory, db, cache)
{ {
public DiscordOidcService(IConfiguration configuration, IHttpClientFactory httpClientFactory, AppDatabase db)
: base(configuration, httpClientFactory, db)
{
}
public override string ProviderName => "Discord"; public override string ProviderName => "Discord";
protected override string DiscoveryEndpoint => ""; // Discord doesn't have a standard OIDC discovery endpoint protected override string DiscoveryEndpoint => ""; // Discord doesn't have a standard OIDC discovery endpoint
protected override string ConfigSectionName => "Discord"; protected override string ConfigSectionName => "Discord";
@ -46,10 +48,11 @@ public class DiscordOidcService : OidcService
return userInfo; return userInfo;
} }
protected override async Task<OidcTokenResponse?> ExchangeCodeForTokensAsync(string code, string? codeVerifier = null) protected override async Task<OidcTokenResponse?> ExchangeCodeForTokensAsync(string code,
string? codeVerifier = null)
{ {
var config = GetProviderConfig(); var config = GetProviderConfig();
var client = _httpClientFactory.CreateClient(); var client = HttpClientFactory.CreateClient();
var content = new FormUrlEncodedContent(new Dictionary<string, string> var content = new FormUrlEncodedContent(new Dictionary<string, string>
{ {
@ -68,7 +71,7 @@ public class DiscordOidcService : OidcService
private async Task<OidcUserInfo> GetUserInfoAsync(string accessToken) private async Task<OidcUserInfo> GetUserInfoAsync(string accessToken)
{ {
var client = _httpClientFactory.CreateClient(); var client = HttpClientFactory.CreateClient();
var request = new HttpRequestMessage(HttpMethod.Get, "https://discord.com/api/users/@me"); var request = new HttpRequestMessage(HttpMethod.Get, "https://discord.com/api/users/@me");
request.Headers.Add("Authorization", $"Bearer {accessToken}"); request.Headers.Add("Authorization", $"Bearer {accessToken}");
@ -85,11 +88,16 @@ public class DiscordOidcService : OidcService
{ {
UserId = userId, UserId = userId,
Email = (discordUser.TryGetProperty("email", out var emailElement) ? emailElement.GetString() : null) ?? "", Email = (discordUser.TryGetProperty("email", out var emailElement) ? emailElement.GetString() : null) ?? "",
EmailVerified = discordUser.TryGetProperty("verified", out var verifiedElement) && verifiedElement.GetBoolean(), EmailVerified = discordUser.TryGetProperty("verified", out var verifiedElement) &&
DisplayName = (discordUser.TryGetProperty("global_name", out var globalNameElement) ? globalNameElement.GetString() : null) ?? "", verifiedElement.GetBoolean(),
DisplayName = (discordUser.TryGetProperty("global_name", out var globalNameElement)
? globalNameElement.GetString()
: null) ?? "",
PreferredUsername = discordUser.GetProperty("username").GetString() ?? "", PreferredUsername = discordUser.GetProperty("username").GetString() ?? "",
ProfilePictureUrl = !string.IsNullOrEmpty(avatar) ? $"https://cdn.discordapp.com/avatars/{userId}/{avatar}.png" : "", ProfilePictureUrl = !string.IsNullOrEmpty(avatar)
? $"https://cdn.discordapp.com/avatars/{userId}/{avatar}.png"
: "",
Provider = ProviderName Provider = ProviderName
}; };
} }
} }

View File

@ -1,15 +1,17 @@
using System.Net.Http.Json; using System.Net.Http.Json;
using System.Text.Json; using System.Text.Json;
using DysonNetwork.Sphere.Storage;
namespace DysonNetwork.Sphere.Auth.OpenId; namespace DysonNetwork.Sphere.Auth.OpenId;
public class GitHubOidcService : OidcService public class GitHubOidcService(
IConfiguration configuration,
IHttpClientFactory httpClientFactory,
AppDatabase db,
ICacheService cache
)
: OidcService(configuration, httpClientFactory, db, cache)
{ {
public GitHubOidcService(IConfiguration configuration, IHttpClientFactory httpClientFactory, AppDatabase db)
: base(configuration, httpClientFactory, db)
{
}
public override string ProviderName => "GitHub"; public override string ProviderName => "GitHub";
protected override string DiscoveryEndpoint => ""; // GitHub doesn't have a standard OIDC discovery endpoint protected override string DiscoveryEndpoint => ""; // GitHub doesn't have a standard OIDC discovery endpoint
protected override string ConfigSectionName => "GitHub"; protected override string ConfigSectionName => "GitHub";
@ -45,10 +47,11 @@ public class GitHubOidcService : OidcService
return userInfo; return userInfo;
} }
protected override async Task<OidcTokenResponse?> ExchangeCodeForTokensAsync(string code, string? codeVerifier = null) protected override async Task<OidcTokenResponse?> ExchangeCodeForTokensAsync(string code,
string? codeVerifier = null)
{ {
var config = GetProviderConfig(); var config = GetProviderConfig();
var client = _httpClientFactory.CreateClient(); var client = HttpClientFactory.CreateClient();
var tokenRequest = new HttpRequestMessage(HttpMethod.Post, "https://github.com/login/oauth/access_token") var tokenRequest = new HttpRequestMessage(HttpMethod.Post, "https://github.com/login/oauth/access_token")
{ {
@ -70,7 +73,7 @@ public class GitHubOidcService : OidcService
private async Task<OidcUserInfo> GetUserInfoAsync(string accessToken) private async Task<OidcUserInfo> GetUserInfoAsync(string accessToken)
{ {
var client = _httpClientFactory.CreateClient(); var client = HttpClientFactory.CreateClient();
var request = new HttpRequestMessage(HttpMethod.Get, "https://api.github.com/user"); var request = new HttpRequestMessage(HttpMethod.Get, "https://api.github.com/user");
request.Headers.Add("Authorization", $"Bearer {accessToken}"); request.Headers.Add("Authorization", $"Bearer {accessToken}");
request.Headers.Add("User-Agent", "DysonNetwork.Sphere"); request.Headers.Add("User-Agent", "DysonNetwork.Sphere");
@ -93,14 +96,16 @@ public class GitHubOidcService : OidcService
Email = email, Email = email,
DisplayName = githubUser.TryGetProperty("name", out var nameElement) ? nameElement.GetString() ?? "" : "", DisplayName = githubUser.TryGetProperty("name", out var nameElement) ? nameElement.GetString() ?? "" : "",
PreferredUsername = githubUser.GetProperty("login").GetString() ?? "", PreferredUsername = githubUser.GetProperty("login").GetString() ?? "",
ProfilePictureUrl = githubUser.TryGetProperty("avatar_url", out var avatarElement) ? avatarElement.GetString() ?? "" : "", ProfilePictureUrl = githubUser.TryGetProperty("avatar_url", out var avatarElement)
? avatarElement.GetString() ?? ""
: "",
Provider = ProviderName Provider = ProviderName
}; };
} }
private async Task<string?> GetPrimaryEmailAsync(string accessToken) private async Task<string?> GetPrimaryEmailAsync(string accessToken)
{ {
var client = _httpClientFactory.CreateClient(); var client = HttpClientFactory.CreateClient();
var request = new HttpRequestMessage(HttpMethod.Get, "https://api.github.com/user/emails"); var request = new HttpRequestMessage(HttpMethod.Get, "https://api.github.com/user/emails");
request.Headers.Add("Authorization", $"Bearer {accessToken}"); request.Headers.Add("Authorization", $"Bearer {accessToken}");
request.Headers.Add("User-Agent", "DysonNetwork.Sphere"); request.Headers.Add("User-Agent", "DysonNetwork.Sphere");
@ -118,4 +123,4 @@ public class GitHubOidcService : OidcService
public bool Primary { get; set; } public bool Primary { get; set; }
public bool Verified { get; set; } public bool Verified { get; set; }
} }
} }

View File

@ -2,6 +2,7 @@ using System.IdentityModel.Tokens.Jwt;
using System.Net.Http.Json; using System.Net.Http.Json;
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Text; using System.Text;
using DysonNetwork.Sphere.Storage;
using Microsoft.IdentityModel.Tokens; using Microsoft.IdentityModel.Tokens;
namespace DysonNetwork.Sphere.Auth.OpenId; namespace DysonNetwork.Sphere.Auth.OpenId;
@ -12,12 +13,13 @@ namespace DysonNetwork.Sphere.Auth.OpenId;
public class GoogleOidcService( public class GoogleOidcService(
IConfiguration configuration, IConfiguration configuration,
IHttpClientFactory httpClientFactory, IHttpClientFactory httpClientFactory,
AppDatabase db AppDatabase db,
ICacheService cache
) )
: OidcService(configuration, httpClientFactory, db) : OidcService(configuration, httpClientFactory, db, cache)
{ {
private readonly IHttpClientFactory _httpClientFactory = httpClientFactory; private readonly IHttpClientFactory _httpClientFactory = httpClientFactory;
public override string ProviderName => "google"; public override string ProviderName => "google";
protected override string DiscoveryEndpoint => "https://accounts.google.com/.well-known/openid-configuration"; protected override string DiscoveryEndpoint => "https://accounts.google.com/.well-known/openid-configuration";
protected override string ConfigSectionName => "Google"; protected override string ConfigSectionName => "Google";
@ -85,16 +87,17 @@ public class GoogleOidcService(
userInfo.RefreshToken = tokenResponse.RefreshToken; userInfo.RefreshToken = tokenResponse.RefreshToken;
// Try to fetch additional profile data if userinfo endpoint is available // Try to fetch additional profile data if userinfo endpoint is available
try try
{ {
var discoveryDocument = await GetDiscoveryDocumentAsync(); var discoveryDocument = await GetDiscoveryDocumentAsync();
if (discoveryDocument?.UserinfoEndpoint != null && !string.IsNullOrEmpty(tokenResponse.AccessToken)) if (discoveryDocument?.UserinfoEndpoint != null && !string.IsNullOrEmpty(tokenResponse.AccessToken))
{ {
var client = _httpClientFactory.CreateClient(); var client = _httpClientFactory.CreateClient();
client.DefaultRequestHeaders.Authorization = client.DefaultRequestHeaders.Authorization =
new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", tokenResponse.AccessToken); new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", tokenResponse.AccessToken);
var userInfoResponse = await client.GetFromJsonAsync<Dictionary<string, object>>(discoveryDocument.UserinfoEndpoint); var userInfoResponse =
await client.GetFromJsonAsync<Dictionary<string, object>>(discoveryDocument.UserinfoEndpoint);
if (userInfoResponse != null) if (userInfoResponse != null)
{ {

View File

@ -1,18 +1,22 @@
using System.Net.Http.Json; using System.Net.Http.Json;
using System.Text.Json; using System.Text.Json;
using DysonNetwork.Sphere.Storage;
namespace DysonNetwork.Sphere.Auth.OpenId; namespace DysonNetwork.Sphere.Auth.OpenId;
public class MicrosoftOidcService : OidcService public class MicrosoftOidcService(
IConfiguration configuration,
IHttpClientFactory httpClientFactory,
AppDatabase db,
ICacheService cache
)
: OidcService(configuration, httpClientFactory, db, cache)
{ {
public MicrosoftOidcService(IConfiguration configuration, IHttpClientFactory httpClientFactory, AppDatabase db)
: base(configuration, httpClientFactory, db)
{
}
public override string ProviderName => "Microsoft"; public override string ProviderName => "Microsoft";
protected override string DiscoveryEndpoint => _configuration[$"Oidc:{ConfigSectionName}:DiscoveryEndpoint"] ?? throw new InvalidOperationException("Microsoft OIDC discovery endpoint is not configured."); protected override string DiscoveryEndpoint => Configuration[$"Oidc:{ConfigSectionName}:DiscoveryEndpoint"] ??
throw new InvalidOperationException(
"Microsoft OIDC discovery endpoint is not configured.");
protected override string ConfigSectionName => "Microsoft"; protected override string ConfigSectionName => "Microsoft";
@ -54,7 +58,8 @@ public class MicrosoftOidcService : OidcService
return userInfo; return userInfo;
} }
protected override async Task<OidcTokenResponse?> ExchangeCodeForTokensAsync(string code, string? codeVerifier = null) protected override async Task<OidcTokenResponse?> ExchangeCodeForTokensAsync(string code,
string? codeVerifier = null)
{ {
var config = GetProviderConfig(); var config = GetProviderConfig();
var discoveryDocument = await GetDiscoveryDocumentAsync(); var discoveryDocument = await GetDiscoveryDocumentAsync();
@ -63,7 +68,7 @@ public class MicrosoftOidcService : OidcService
throw new InvalidOperationException("Token endpoint not found in discovery document."); throw new InvalidOperationException("Token endpoint not found in discovery document.");
} }
var client = _httpClientFactory.CreateClient(); var client = HttpClientFactory.CreateClient();
var tokenRequest = new HttpRequestMessage(HttpMethod.Post, discoveryDocument.TokenEndpoint) var tokenRequest = new HttpRequestMessage(HttpMethod.Post, discoveryDocument.TokenEndpoint)
{ {
@ -90,7 +95,7 @@ public class MicrosoftOidcService : OidcService
if (discoveryDocument?.UserinfoEndpoint == null) if (discoveryDocument?.UserinfoEndpoint == null)
throw new InvalidOperationException("Userinfo endpoint not found in discovery document."); throw new InvalidOperationException("Userinfo endpoint not found in discovery document.");
var client = _httpClientFactory.CreateClient(); var client = HttpClientFactory.CreateClient();
var request = new HttpRequestMessage(HttpMethod.Get, discoveryDocument.UserinfoEndpoint); var request = new HttpRequestMessage(HttpMethod.Get, discoveryDocument.UserinfoEndpoint);
request.Headers.Add("Authorization", $"Bearer {accessToken}"); request.Headers.Add("Authorization", $"Bearer {accessToken}");
@ -104,10 +109,15 @@ public class MicrosoftOidcService : OidcService
{ {
UserId = microsoftUser.GetProperty("sub").GetString() ?? "", UserId = microsoftUser.GetProperty("sub").GetString() ?? "",
Email = microsoftUser.TryGetProperty("email", out var emailElement) ? emailElement.GetString() : null, Email = microsoftUser.TryGetProperty("email", out var emailElement) ? emailElement.GetString() : null,
DisplayName = microsoftUser.TryGetProperty("name", out var nameElement) ? nameElement.GetString() ?? "" : "", DisplayName =
PreferredUsername = microsoftUser.TryGetProperty("preferred_username", out var preferredUsernameElement) ? preferredUsernameElement.GetString() ?? "" : "", microsoftUser.TryGetProperty("name", out var nameElement) ? nameElement.GetString() ?? "" : "",
ProfilePictureUrl = microsoftUser.TryGetProperty("picture", out var pictureElement) ? pictureElement.GetString() ?? "" : "", PreferredUsername = microsoftUser.TryGetProperty("preferred_username", out var preferredUsernameElement)
? preferredUsernameElement.GetString() ?? ""
: "",
ProfilePictureUrl = microsoftUser.TryGetProperty("picture", out var pictureElement)
? pictureElement.GetString() ?? ""
: "",
Provider = ProviderName Provider = ProviderName
}; };
} }
} }

View File

@ -2,6 +2,7 @@ using System.IdentityModel.Tokens.Jwt;
using System.Net.Http.Json; using System.Net.Http.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using DysonNetwork.Sphere.Account; using DysonNetwork.Sphere.Account;
using DysonNetwork.Sphere.Storage;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using Microsoft.IdentityModel.Tokens; using Microsoft.IdentityModel.Tokens;
using NodaTime; using NodaTime;
@ -11,18 +12,16 @@ namespace DysonNetwork.Sphere.Auth.OpenId;
/// <summary> /// <summary>
/// Base service for OpenID Connect authentication providers /// Base service for OpenID Connect authentication providers
/// </summary> /// </summary>
public abstract class OidcService public abstract class OidcService(
IConfiguration configuration,
IHttpClientFactory httpClientFactory,
AppDatabase db,
ICacheService cache
)
{ {
protected readonly IConfiguration _configuration; protected readonly IConfiguration Configuration = configuration;
protected readonly IHttpClientFactory _httpClientFactory; protected readonly IHttpClientFactory HttpClientFactory = httpClientFactory;
protected readonly AppDatabase _db; protected readonly AppDatabase Db = db;
protected OidcService(IConfiguration configuration, IHttpClientFactory httpClientFactory, AppDatabase db)
{
_configuration = configuration;
_httpClientFactory = httpClientFactory;
_db = db;
}
/// <summary> /// <summary>
/// Gets the unique identifier for this provider /// Gets the unique identifier for this provider
@ -56,9 +55,9 @@ public abstract class OidcService
{ {
return new ProviderConfiguration return new ProviderConfiguration
{ {
ClientId = _configuration[$"Oidc:{ConfigSectionName}:ClientId"] ?? "", ClientId = Configuration[$"Oidc:{ConfigSectionName}:ClientId"] ?? "",
ClientSecret = _configuration[$"Oidc:{ConfigSectionName}:ClientSecret"] ?? "", ClientSecret = Configuration[$"Oidc:{ConfigSectionName}:ClientSecret"] ?? "",
RedirectUri = _configuration["BaseUrl"] + "/auth/callback/" + ProviderName.ToLower() RedirectUri = Configuration["BaseUrl"] + "/auth/callback/" + ProviderName.ToLower()
}; };
} }
@ -67,10 +66,28 @@ public abstract class OidcService
/// </summary> /// </summary>
protected async Task<OidcDiscoveryDocument?> GetDiscoveryDocumentAsync() protected async Task<OidcDiscoveryDocument?> GetDiscoveryDocumentAsync()
{ {
var client = _httpClientFactory.CreateClient(); // Construct a cache key unique to the current provider:
var cacheKey = $"oidc-discovery:{ProviderName}";
// Try getting the discovery document from cache first:
var (found, cachedDoc) = await cache.GetAsyncWithStatus<OidcDiscoveryDocument>(cacheKey);
if (found && cachedDoc != null)
{
return cachedDoc;
}
// If it's not cached, fetch from the actual discovery endpoint:
var client = HttpClientFactory.CreateClient();
var response = await client.GetAsync(DiscoveryEndpoint); var response = await client.GetAsync(DiscoveryEndpoint);
response.EnsureSuccessStatusCode(); response.EnsureSuccessStatusCode();
return await response.Content.ReadFromJsonAsync<OidcDiscoveryDocument>(); var doc = await response.Content.ReadFromJsonAsync<OidcDiscoveryDocument>();
// Store the discovery document in the cache for a while (e.g., 15 minutes):
if (doc is not null)
await cache.SetAsync(cacheKey, doc, TimeSpan.FromMinutes(15));
return doc;
} }
/// <summary> /// <summary>
@ -87,7 +104,7 @@ public abstract class OidcService
throw new InvalidOperationException("Token endpoint not found in discovery document"); throw new InvalidOperationException("Token endpoint not found in discovery document");
} }
var client = _httpClientFactory.CreateClient(); var client = HttpClientFactory.CreateClient();
var content = new FormUrlEncodedContent(BuildTokenRequestParameters(code, config, codeVerifier)); var content = new FormUrlEncodedContent(BuildTokenRequestParameters(code, config, codeVerifier));
var response = await client.PostAsync(discoveryDocument.TokenEndpoint, content); var response = await client.PostAsync(discoveryDocument.TokenEndpoint, content);
@ -178,7 +195,7 @@ public abstract class OidcService
) )
{ {
// Create or update the account connection // Create or update the account connection
var connection = await _db.AccountConnections var connection = await Db.AccountConnections
.FirstOrDefaultAsync(c => c.Provider == ProviderName && .FirstOrDefaultAsync(c => c.Provider == ProviderName &&
c.ProvidedIdentifier == userInfo.UserId && c.ProvidedIdentifier == userInfo.UserId &&
c.AccountId == account.Id c.AccountId == account.Id
@ -195,7 +212,7 @@ public abstract class OidcService
LastUsedAt = SystemClock.Instance.GetCurrentInstant(), LastUsedAt = SystemClock.Instance.GetCurrentInstant(),
AccountId = account.Id AccountId = account.Id
}; };
await _db.AccountConnections.AddAsync(connection); await Db.AccountConnections.AddAsync(connection);
} }
// Create a challenge that's already completed // Create a challenge that's already completed
@ -215,7 +232,7 @@ public abstract class OidcService
UserAgent = request.Request.Headers.UserAgent, UserAgent = request.Request.Headers.UserAgent,
}; };
await _db.AuthChallenges.AddAsync(challenge); await Db.AuthChallenges.AddAsync(challenge);
// Create a session // Create a session
var session = new Session var session = new Session
@ -226,8 +243,8 @@ public abstract class OidcService
Challenge = challenge Challenge = challenge
}; };
await _db.AuthSessions.AddAsync(session); await Db.AuthSessions.AddAsync(session);
await _db.SaveChangesAsync(); await Db.SaveChangesAsync();
return session; return session;
} }