🐛 Fixes in auth service

This commit is contained in:
2025-08-25 22:24:18 +08:00
parent 081815c512
commit 442ee3bcfd
4 changed files with 120 additions and 23 deletions

View File

@@ -49,7 +49,10 @@ public class DysonTokenAuthHandler(
try try
{ {
var (valid, session, message) = await token.AuthenticateTokenAsync(tokenInfo.Token); // Get client IP address
var ipAddress = Context.Connection.RemoteIpAddress?.ToString();
var (valid, session, message) = await token.AuthenticateTokenAsync(tokenInfo.Token, ipAddress);
if (!valid || session is null) if (!valid || session is null)
return AuthenticateResult.Fail(message ?? "Authentication failed."); return AuthenticateResult.Fail(message ?? "Authentication failed.");

View File

@@ -20,7 +20,6 @@ public class TokenResponse
[JsonPropertyName("scope")] [JsonPropertyName("scope")]
public string? Scope { get; set; } public string? Scope { get; set; }
[JsonPropertyName("id_token")] [JsonPropertyName("id_token")]
public string? IdToken { get; set; } public string? IdToken { get; set; }
} }

View File

@@ -11,6 +11,7 @@ using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Microsoft.IdentityModel.Tokens; using Microsoft.IdentityModel.Tokens;
using NodaTime; using NodaTime;
using AccountContactType = DysonNetwork.Pass.Account.AccountContactType;
namespace DysonNetwork.Pass.Auth.OidcProvider.Services; namespace DysonNetwork.Pass.Auth.OidcProvider.Services;
@@ -37,12 +38,21 @@ public class OidcProviderService(
return resp.App ?? null; return resp.App ?? null;
} }
public async Task<AuthSession?> FindValidSessionAsync(Guid accountId, Guid clientId) public async Task<AuthSession?> FindValidSessionAsync(Guid accountId, Guid clientId, bool withAccount = false)
{ {
var now = SystemClock.Instance.GetCurrentInstant(); var now = SystemClock.Instance.GetCurrentInstant();
return await db.AuthSessions var queryable = db.AuthSessions
.Include(s => s.Challenge) .Include(s => s.Challenge)
.AsQueryable();
if (withAccount)
queryable = queryable
.Include(s => s.Account)
.ThenInclude(a => a.Profile)
.Include(a => a.Account.Contacts)
.AsQueryable();
return await queryable
.Where(s => s.AccountId == accountId && .Where(s => s.AccountId == accountId &&
s.AppId == clientId && s.AppId == clientId &&
(s.ExpiredAt == null || s.ExpiredAt > now) && (s.ExpiredAt == null || s.ExpiredAt > now) &&
@@ -67,12 +77,12 @@ public class OidcProviderService(
{ {
if (string.IsNullOrEmpty(redirectUri)) if (string.IsNullOrEmpty(redirectUri))
return false; return false;
var client = await FindClientByIdAsync(clientId); var client = await FindClientByIdAsync(clientId);
if (client?.Status != CustomAppStatus.Production) if (client?.Status != CustomAppStatus.Production)
return true; return true;
if (client?.OauthConfig?.RedirectUris == null) if (client?.OauthConfig?.RedirectUris == null)
return false; return false;
@@ -114,7 +124,7 @@ public class OidcProviderService(
var allowedPath = allowedUriObj.AbsolutePath.TrimEnd('/'); var allowedPath = allowedUriObj.AbsolutePath.TrimEnd('/');
var redirectPath = redirectUriObj.AbsolutePath.TrimEnd('/'); var redirectPath = redirectUriObj.AbsolutePath.TrimEnd('/');
if (string.IsNullOrEmpty(allowedPath) || if (string.IsNullOrEmpty(allowedPath) ||
redirectPath.StartsWith(allowedPath, StringComparison.OrdinalIgnoreCase)) redirectPath.StartsWith(allowedPath, StringComparison.OrdinalIgnoreCase))
{ {
return true; return true;
@@ -133,6 +143,79 @@ public class OidcProviderService(
return false; return false;
} }
private string GenerateIdToken(
CustomApp client,
AuthSession session,
string? nonce = null,
IEnumerable<string>? scopes = null
)
{
var tokenHandler = new JwtSecurityTokenHandler();
var clock = SystemClock.Instance;
var now = clock.GetCurrentInstant();
var claims = new List<Claim>
{
new Claim(JwtRegisteredClaimNames.Iss, _options.IssuerUri),
new Claim(JwtRegisteredClaimNames.Sub, session.AccountId.ToString()),
new Claim(JwtRegisteredClaimNames.Aud, client.Id.ToString()),
new Claim(JwtRegisteredClaimNames.Iat, now.ToUnixTimeSeconds().ToString(), ClaimValueTypes.Integer64),
new Claim(JwtRegisteredClaimNames.Exp,
now.Plus(Duration.FromSeconds(_options.AccessTokenLifetime.TotalSeconds)).ToUnixTimeSeconds()
.ToString(), ClaimValueTypes.Integer64),
new Claim(JwtRegisteredClaimNames.AuthTime, session.CreatedAt.ToUnixTimeSeconds().ToString(),
ClaimValueTypes.Integer64)
};
// Add nonce if provided (required for implicit and hybrid flows)
if (!string.IsNullOrEmpty(nonce))
{
claims.Add(new Claim("nonce", nonce));
}
// Add email claim if email scope is requested
var scopesList = scopes?.ToList() ?? [];
if (scopesList.Contains("email"))
{
var contact = session.Account.Contacts.FirstOrDefault(c => c.Type == AccountContactType.Email);
if (contact is not null)
{
claims.Add(new Claim(JwtRegisteredClaimNames.Email, contact.Content));
claims.Add(new Claim("email_verified", contact.VerifiedAt is not null ? "true" : "false",
ClaimValueTypes.Boolean));
}
}
// Add profile claims if profile scope is requested
if (scopes != null && scopesList.Contains("profile"))
{
if (!string.IsNullOrEmpty(session.Account.Name))
claims.Add(new Claim("preferred_username", session.Account.Name));
if (!string.IsNullOrEmpty(session.Account.Nick))
claims.Add(new Claim("name", session.Account.Nick));
if (!string.IsNullOrEmpty(session.Account.Profile.FirstName))
claims.Add(new Claim("given_name", session.Account.Profile.FirstName));
if (!string.IsNullOrEmpty(session.Account.Profile.LastName))
claims.Add(new Claim("family_name", session.Account.Profile.LastName));
}
var tokenDescriptor = new SecurityTokenDescriptor
{
Subject = new ClaimsIdentity(claims),
Issuer = _options.IssuerUri,
Audience = client.Id.ToString(),
Expires = now.Plus(Duration.FromSeconds(_options.AccessTokenLifetime.TotalSeconds)).ToDateTimeUtc(),
NotBefore = now.ToDateTimeUtc(),
SigningCredentials = new SigningCredentials(
new RsaSecurityKey(_options.GetRsaPrivateKey()),
SecurityAlgorithms.RsaSha256
)
};
var token = tokenHandler.CreateToken(tokenDescriptor);
return tokenHandler.WriteToken(token);
}
public async Task<TokenResponse> GenerateTokenResponseAsync( public async Task<TokenResponse> GenerateTokenResponseAsync(
Guid clientId, Guid clientId,
string? authorizationCode = null, string? authorizationCode = null,
@@ -148,24 +231,28 @@ public class OidcProviderService(
AuthSession session; AuthSession session;
var clock = SystemClock.Instance; var clock = SystemClock.Instance;
var now = clock.GetCurrentInstant(); var now = clock.GetCurrentInstant();
string? nonce = null;
List<string>? scopes = null; List<string>? scopes = null;
if (authorizationCode != null) if (authorizationCode != null)
{ {
// Authorization code flow // Authorization code flow
var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId, redirectUri, codeVerifier); var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId, redirectUri, codeVerifier);
if (authCode is null) throw new InvalidOperationException("Invalid authorization code"); if (authCode == null)
var account = await db.Accounts.Where(a => a.Id == authCode.AccountId).FirstOrDefaultAsync(); throw new InvalidOperationException("Invalid authorization code");
if (account is null) throw new InvalidOperationException("Account was not found");
// Load the session for the user
session = await FindValidSessionAsync(authCode.AccountId, clientId, withAccount: true) ??
throw new InvalidOperationException("No valid session found for user");
session = await auth.CreateSessionForOidcAsync(account, now, clientId);
scopes = authCode.Scopes; scopes = authCode.Scopes;
nonce = authCode.Nonce;
} }
else if (sessionId.HasValue) else if (sessionId.HasValue)
{ {
// Refresh token flow // Refresh token flow
session = await FindSessionByIdAsync(sessionId.Value) ?? session = await FindSessionByIdAsync(sessionId.Value) ??
throw new InvalidOperationException("Invalid session"); throw new InvalidOperationException("Session not found");
// Verify the session is still valid // Verify the session is still valid
if (session.ExpiredAt < now) if (session.ExpiredAt < now)
@@ -179,13 +266,15 @@ public class OidcProviderService(
var expiresIn = (int)_options.AccessTokenLifetime.TotalSeconds; var expiresIn = (int)_options.AccessTokenLifetime.TotalSeconds;
var expiresAt = now.Plus(Duration.FromSeconds(expiresIn)); var expiresAt = now.Plus(Duration.FromSeconds(expiresIn));
// Generate an access token // Generate tokens
var accessToken = GenerateJwtToken(client, session, expiresAt, scopes); var accessToken = GenerateJwtToken(client, session, expiresAt, scopes);
var idToken = GenerateIdToken(client, session, nonce, scopes);
var refreshToken = GenerateRefreshToken(session); var refreshToken = GenerateRefreshToken(session);
return new TokenResponse return new TokenResponse
{ {
AccessToken = accessToken, AccessToken = accessToken,
IdToken = idToken,
ExpiresIn = expiresIn, ExpiresIn = expiresIn,
TokenType = "Bearer", TokenType = "Bearer",
RefreshToken = refreshToken, RefreshToken = refreshToken,
@@ -317,12 +406,13 @@ public class OidcProviderService(
Nonce = nonce, Nonce = nonce,
CreatedAt = now CreatedAt = now
}; };
// Store the code with its metadata in the cache // Store the code with its metadata in the cache
var cacheKey = $"auth:code:{code}"; var cacheKey = $"auth:code:{code}";
await cache.SetAsync(cacheKey, authCodeInfo, _options.AuthorizationCodeLifetime); await cache.SetAsync(cacheKey, authCodeInfo, _options.AuthorizationCodeLifetime);
logger.LogInformation("Generated authorization code for client {ClientId} and user {UserId}", clientId, session.AccountId); logger.LogInformation("Generated authorization code for client {ClientId} and user {UserId}", clientId,
session.AccountId);
return code; return code;
} }

View File

@@ -23,7 +23,7 @@ public class TokenAuthService(
/// </summary> /// </summary>
/// <param name="token">Incoming token string</param> /// <param name="token">Incoming token string</param>
/// <returns>(Valid, Session, Message)</returns> /// <returns>(Valid, Session, Message)</returns>
public async Task<(bool Valid, AuthSession? Session, string? Message)> AuthenticateTokenAsync(string token) public async Task<(bool Valid, AuthSession? Session, string? Message)> AuthenticateTokenAsync(string token, string? ipAddress = null)
{ {
try try
{ {
@@ -32,6 +32,11 @@ public class TokenAuthService(
logger.LogWarning("AuthenticateTokenAsync: no token provided"); logger.LogWarning("AuthenticateTokenAsync: no token provided");
return (false, null, "No token provided."); return (false, null, "No token provided.");
} }
if (!string.IsNullOrEmpty(ipAddress))
{
logger.LogDebug("AuthenticateTokenAsync: client IP: {IpAddress}", ipAddress);
}
// token fingerprint for correlation // token fingerprint for correlation
var tokenHash = Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(token))); var tokenHash = Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(token)));
@@ -70,7 +75,7 @@ public class TokenAuthService(
"AuthenticateTokenAsync: success via cache (sessionId={SessionId}, accountId={AccountId}, scopes={ScopeCount}, expiresAt={ExpiresAt})", "AuthenticateTokenAsync: success via cache (sessionId={SessionId}, accountId={AccountId}, scopes={ScopeCount}, expiresAt={ExpiresAt})",
sessionId, sessionId,
session.AccountId, session.AccountId,
session.Challenge.Scopes.Count, session.Challenge?.Scopes.Count,
session.ExpiredAt session.ExpiredAt
); );
return (true, session, null); return (true, session, null);
@@ -103,11 +108,11 @@ public class TokenAuthService(
"AuthenticateTokenAsync: DB session loaded (sessionId={SessionId}, accountId={AccountId}, clientId={ClientId}, appId={AppId}, scopes={ScopeCount}, ip={Ip}, uaLen={UaLen})", "AuthenticateTokenAsync: DB session loaded (sessionId={SessionId}, accountId={AccountId}, clientId={ClientId}, appId={AppId}, scopes={ScopeCount}, ip={Ip}, uaLen={UaLen})",
sessionId, sessionId,
session.AccountId, session.AccountId,
session.Challenge.ClientId, session.Challenge?.ClientId,
session.AppId, session.AppId,
session.Challenge.Scopes.Count, session.Challenge?.Scopes.Count,
session.Challenge.IpAddress, session.Challenge?.IpAddress,
(session.Challenge.UserAgent ?? string.Empty).Length (session.Challenge?.UserAgent ?? string.Empty).Length
); );
logger.LogDebug("AuthenticateTokenAsync: enriching account with subscription (accountId={AccountId})", session.AccountId); logger.LogDebug("AuthenticateTokenAsync: enriching account with subscription (accountId={AccountId})", session.AccountId);
@@ -136,7 +141,7 @@ public class TokenAuthService(
"AuthenticateTokenAsync: success via DB (sessionId={SessionId}, accountId={AccountId}, clientId={ClientId})", "AuthenticateTokenAsync: success via DB (sessionId={SessionId}, accountId={AccountId}, clientId={ClientId})",
sessionId, sessionId,
session.AccountId, session.AccountId,
session.Challenge.ClientId session.Challenge?.ClientId
); );
return (true, session, null); return (true, session, null);
} }