♻️ Refactored oidc
This commit is contained in:
@ -2,13 +2,11 @@ using System.IdentityModel.Tokens.Jwt;
|
||||
using System.Security.Claims;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using System.Text.Encodings.Web;
|
||||
using System.Text.Json;
|
||||
using DysonNetwork.Sphere.Auth.OidcProvider.Models;
|
||||
using DysonNetwork.Sphere.Auth.OidcProvider.Options;
|
||||
using DysonNetwork.Sphere.Auth.OidcProvider.Responses;
|
||||
using DysonNetwork.Sphere.Developer;
|
||||
using Microsoft.AspNetCore.Identity;
|
||||
using DysonNetwork.Sphere.Storage;
|
||||
using Microsoft.EntityFrameworkCore;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Microsoft.IdentityModel.Tokens;
|
||||
@ -18,7 +16,8 @@ namespace DysonNetwork.Sphere.Auth.OidcProvider.Services;
|
||||
|
||||
public class OidcProviderService(
|
||||
AppDatabase db,
|
||||
IClock clock,
|
||||
AuthService auth,
|
||||
ICacheService cache,
|
||||
IOptions<OidcProviderOptions> options,
|
||||
ILogger<OidcProviderService> logger
|
||||
)
|
||||
@ -44,6 +43,7 @@ public class OidcProviderService(
|
||||
var client = await FindClientByIdAsync(clientId);
|
||||
if (client == null) return false;
|
||||
|
||||
var clock = SystemClock.Instance;
|
||||
var secret = client.Secrets
|
||||
.Where(s => s.IsOidc && (s.ExpiredAt == null || s.ExpiredAt > clock.GetCurrentInstant()))
|
||||
.FirstOrDefault(s => s.Secret == clientSecret); // In production, use proper hashing
|
||||
@ -53,20 +53,28 @@ public class OidcProviderService(
|
||||
|
||||
public async Task<TokenResponse> GenerateTokenResponseAsync(
|
||||
Guid clientId,
|
||||
string subjectId,
|
||||
IEnumerable<string>? scopes = null,
|
||||
string? authorizationCode = null)
|
||||
string authorizationCode,
|
||||
IEnumerable<string>? scopes = null
|
||||
)
|
||||
{
|
||||
var client = await FindClientByIdAsync(clientId);
|
||||
if (client == null)
|
||||
throw new InvalidOperationException("Client not found");
|
||||
|
||||
var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId);
|
||||
if (authCode is null) throw new InvalidOperationException("Invalid authorization code");
|
||||
var account = await db.Accounts.Where(a => a.Id == authCode.AccountId).FirstOrDefaultAsync();
|
||||
if (account is null) throw new InvalidOperationException("Account was not found");
|
||||
|
||||
var clock = SystemClock.Instance;
|
||||
var now = clock.GetCurrentInstant();
|
||||
var session = await auth.CreateSessionAsync(account, now);
|
||||
|
||||
var expiresIn = (int)_options.AccessTokenLifetime.TotalSeconds;
|
||||
var expiresAt = now.Plus(Duration.FromSeconds(expiresIn));
|
||||
|
||||
// Generate access token
|
||||
var accessToken = GenerateJwtToken(client, subjectId, expiresAt, scopes);
|
||||
var accessToken = GenerateJwtToken(client, session, expiresAt, scopes);
|
||||
var refreshToken = GenerateRefreshToken();
|
||||
|
||||
// In a real implementation, you would store the token in the database
|
||||
@ -83,27 +91,33 @@ public class OidcProviderService(
|
||||
};
|
||||
}
|
||||
|
||||
private string GenerateJwtToken(CustomApp client, string subjectId, Instant expiresAt, IEnumerable<string>? scopes = null)
|
||||
private string GenerateJwtToken(
|
||||
CustomApp client,
|
||||
Session session,
|
||||
Instant expiresAt,
|
||||
IEnumerable<string>? scopes = null
|
||||
)
|
||||
{
|
||||
var tokenHandler = new JwtSecurityTokenHandler();
|
||||
var key = Encoding.ASCII.GetBytes(_options.SigningKey);
|
||||
|
||||
var clock = SystemClock.Instance;
|
||||
var tokenDescriptor = new SecurityTokenDescriptor
|
||||
{
|
||||
Subject = new ClaimsIdentity(new[]
|
||||
{
|
||||
new Claim(JwtRegisteredClaimNames.Sub, subjectId),
|
||||
Subject = new ClaimsIdentity([
|
||||
new Claim(JwtRegisteredClaimNames.Sub, session.Id.ToString()),
|
||||
new Claim(JwtRegisteredClaimNames.Jti, Guid.NewGuid().ToString()),
|
||||
new Claim(JwtRegisteredClaimNames.Iat, clock.GetCurrentInstant().ToUnixTimeSeconds().ToString(),
|
||||
ClaimValueTypes.Integer64),
|
||||
new Claim("client_id", client.Id.ToString())
|
||||
}),
|
||||
]),
|
||||
Expires = expiresAt.ToDateTimeUtc(),
|
||||
Issuer = _options.IssuerUri,
|
||||
Audience = client.Id.ToString(),
|
||||
SigningCredentials = new SigningCredentials(
|
||||
new SymmetricSecurityKey(key),
|
||||
SecurityAlgorithms.HmacSha256Signature)
|
||||
SecurityAlgorithms.HmacSha256Signature
|
||||
)
|
||||
};
|
||||
|
||||
// Add scopes as claims if provided, otherwise use client's default scopes
|
||||
@ -163,11 +177,11 @@ public class OidcProviderService(
|
||||
}
|
||||
}
|
||||
|
||||
private static readonly Dictionary<string, AuthorizationCodeInfo> _authorizationCodes = new();
|
||||
|
||||
// Authorization codes are now managed through ICacheService
|
||||
|
||||
public async Task<string> GenerateAuthorizationCodeAsync(
|
||||
Guid clientId,
|
||||
string userId,
|
||||
Guid userId,
|
||||
string redirectUri,
|
||||
IEnumerable<string> scopes,
|
||||
string? codeChallenge = null,
|
||||
@ -175,14 +189,15 @@ public class OidcProviderService(
|
||||
string? nonce = null)
|
||||
{
|
||||
// Generate a random code
|
||||
var clock = SystemClock.Instance;
|
||||
var code = GenerateRandomString(32);
|
||||
var now = clock.GetCurrentInstant();
|
||||
|
||||
// Store the code with its metadata
|
||||
_authorizationCodes[code] = new AuthorizationCodeInfo
|
||||
|
||||
// Create the authorization code info
|
||||
var authCodeInfo = new AuthorizationCodeInfo
|
||||
{
|
||||
ClientId = clientId,
|
||||
UserId = userId,
|
||||
AccountId = userId,
|
||||
RedirectUri = redirectUri,
|
||||
Scopes = scopes.ToList(),
|
||||
CodeChallenge = codeChallenge,
|
||||
@ -191,48 +206,58 @@ public class OidcProviderService(
|
||||
Expiration = now.Plus(Duration.FromTimeSpan(_options.AuthorizationCodeLifetime)),
|
||||
CreatedAt = now
|
||||
};
|
||||
|
||||
|
||||
// Store the code with its metadata in the cache
|
||||
var cacheKey = $"auth:code:{code}";
|
||||
await cache.SetAsync(cacheKey, authCodeInfo, _options.AuthorizationCodeLifetime);
|
||||
|
||||
logger.LogInformation("Generated authorization code for client {ClientId} and user {UserId}", clientId, userId);
|
||||
return code;
|
||||
}
|
||||
|
||||
|
||||
public async Task<AuthorizationCodeInfo?> ValidateAuthorizationCodeAsync(
|
||||
string code,
|
||||
Guid clientId,
|
||||
string? redirectUri = null,
|
||||
string? codeVerifier = null)
|
||||
string? codeVerifier = null
|
||||
)
|
||||
{
|
||||
if (!_authorizationCodes.TryGetValue(code, out var authCode) || authCode == null)
|
||||
var cacheKey = $"auth:code:{code}";
|
||||
var (found, authCode) = await cache.GetAsyncWithStatus<AuthorizationCodeInfo>(cacheKey);
|
||||
|
||||
if (!found || authCode == null)
|
||||
{
|
||||
logger.LogWarning("Authorization code not found: {Code}", code);
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
var clock = SystemClock.Instance;
|
||||
var now = clock.GetCurrentInstant();
|
||||
|
||||
// Check if code has expired
|
||||
|
||||
// Check if the code has expired
|
||||
if (now > authCode.Expiration)
|
||||
{
|
||||
logger.LogWarning("Authorization code expired: {Code}", code);
|
||||
_authorizationCodes.Remove(code);
|
||||
await cache.RemoveAsync(cacheKey);
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
// Verify client ID matches
|
||||
if (authCode.ClientId != clientId)
|
||||
{
|
||||
logger.LogWarning("Client ID mismatch for code {Code}. Expected: {ExpectedClientId}, Actual: {ActualClientId}",
|
||||
logger.LogWarning(
|
||||
"Client ID mismatch for code {Code}. Expected: {ExpectedClientId}, Actual: {ActualClientId}",
|
||||
code, authCode.ClientId, clientId);
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
// Verify redirect URI if provided
|
||||
if (!string.IsNullOrEmpty(redirectUri) && authCode.RedirectUri != redirectUri)
|
||||
{
|
||||
logger.LogWarning("Redirect URI mismatch for code {Code}", code);
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
// Verify PKCE code challenge if one was provided during authorization
|
||||
if (!string.IsNullOrEmpty(authCode.CodeChallenge))
|
||||
{
|
||||
@ -241,33 +266,33 @@ public class OidcProviderService(
|
||||
logger.LogWarning("PKCE code verifier is required but not provided for code {Code}", code);
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
var isValid = authCode.CodeChallengeMethod?.ToUpperInvariant() switch
|
||||
{
|
||||
"S256" => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, "S256"),
|
||||
"PLAIN" => VerifyCodeChallenge(codeVerifier, authCode.CodeChallenge, "PLAIN"),
|
||||
_ => false // Unsupported code challenge method
|
||||
};
|
||||
|
||||
|
||||
if (!isValid)
|
||||
{
|
||||
logger.LogWarning("PKCE code verifier validation failed for code {Code}", code);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
// Code is valid, remove it from the store (codes are single-use)
|
||||
_authorizationCodes.Remove(code);
|
||||
|
||||
|
||||
// Code is valid, remove it from the cache (codes are single-use)
|
||||
await cache.RemoveAsync(cacheKey);
|
||||
|
||||
return authCode;
|
||||
}
|
||||
|
||||
|
||||
private static string GenerateRandomString(int length)
|
||||
{
|
||||
const string chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~";
|
||||
var random = RandomNumberGenerator.Create();
|
||||
var result = new char[length];
|
||||
|
||||
|
||||
for (int i = 0; i < length; i++)
|
||||
{
|
||||
var randomNumber = new byte[4];
|
||||
@ -275,14 +300,14 @@ public class OidcProviderService(
|
||||
var index = (int)(BitConverter.ToUInt32(randomNumber, 0) % chars.Length);
|
||||
result[i] = chars[index];
|
||||
}
|
||||
|
||||
|
||||
return new string(result);
|
||||
}
|
||||
|
||||
|
||||
private static bool VerifyCodeChallenge(string codeVerifier, string codeChallenge, string method)
|
||||
{
|
||||
if (string.IsNullOrEmpty(codeVerifier)) return false;
|
||||
|
||||
|
||||
if (method == "S256")
|
||||
{
|
||||
using var sha256 = SHA256.Create();
|
||||
@ -290,12 +315,12 @@ public class OidcProviderService(
|
||||
var base64 = Base64UrlEncoder.Encode(hash);
|
||||
return string.Equals(base64, codeChallenge, StringComparison.Ordinal);
|
||||
}
|
||||
|
||||
|
||||
if (method == "PLAIN")
|
||||
{
|
||||
return string.Equals(codeVerifier, codeChallenge, StringComparison.Ordinal);
|
||||
}
|
||||
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user