diff --git a/DysonNetwork.Pass/Auth/Auth.cs b/DysonNetwork.Pass/Auth/Auth.cs index 0f5b019..30c26d5 100644 --- a/DysonNetwork.Pass/Auth/Auth.cs +++ b/DysonNetwork.Pass/Auth/Auth.cs @@ -49,7 +49,10 @@ public class DysonTokenAuthHandler( try { - var (valid, session, message) = await token.AuthenticateTokenAsync(tokenInfo.Token); + // Get client IP address + var ipAddress = Context.Connection.RemoteIpAddress?.ToString(); + + var (valid, session, message) = await token.AuthenticateTokenAsync(tokenInfo.Token, ipAddress); if (!valid || session is null) return AuthenticateResult.Fail(message ?? "Authentication failed."); diff --git a/DysonNetwork.Pass/Auth/OidcProvider/Responses/TokenResponse.cs b/DysonNetwork.Pass/Auth/OidcProvider/Responses/TokenResponse.cs index 3c7e50c..1eecfeb 100644 --- a/DysonNetwork.Pass/Auth/OidcProvider/Responses/TokenResponse.cs +++ b/DysonNetwork.Pass/Auth/OidcProvider/Responses/TokenResponse.cs @@ -20,7 +20,6 @@ public class TokenResponse [JsonPropertyName("scope")] public string? Scope { get; set; } - [JsonPropertyName("id_token")] public string? IdToken { get; set; } } diff --git a/DysonNetwork.Pass/Auth/OidcProvider/Services/OidcProviderService.cs b/DysonNetwork.Pass/Auth/OidcProvider/Services/OidcProviderService.cs index aa417b2..d3be946 100644 --- a/DysonNetwork.Pass/Auth/OidcProvider/Services/OidcProviderService.cs +++ b/DysonNetwork.Pass/Auth/OidcProvider/Services/OidcProviderService.cs @@ -11,6 +11,7 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Options; using Microsoft.IdentityModel.Tokens; using NodaTime; +using AccountContactType = DysonNetwork.Pass.Account.AccountContactType; namespace DysonNetwork.Pass.Auth.OidcProvider.Services; @@ -37,12 +38,21 @@ public class OidcProviderService( return resp.App ?? null; } - public async Task FindValidSessionAsync(Guid accountId, Guid clientId) + public async Task FindValidSessionAsync(Guid accountId, Guid clientId, bool withAccount = false) { var now = SystemClock.Instance.GetCurrentInstant(); - return await db.AuthSessions + var queryable = db.AuthSessions .Include(s => s.Challenge) + .AsQueryable(); + if (withAccount) + queryable = queryable + .Include(s => s.Account) + .ThenInclude(a => a.Profile) + .Include(a => a.Account.Contacts) + .AsQueryable(); + + return await queryable .Where(s => s.AccountId == accountId && s.AppId == clientId && (s.ExpiredAt == null || s.ExpiredAt > now) && @@ -67,12 +77,12 @@ public class OidcProviderService( { if (string.IsNullOrEmpty(redirectUri)) return false; - + var client = await FindClientByIdAsync(clientId); if (client?.Status != CustomAppStatus.Production) return true; - + if (client?.OauthConfig?.RedirectUris == null) return false; @@ -114,7 +124,7 @@ public class OidcProviderService( var allowedPath = allowedUriObj.AbsolutePath.TrimEnd('/'); var redirectPath = redirectUriObj.AbsolutePath.TrimEnd('/'); - if (string.IsNullOrEmpty(allowedPath) || + if (string.IsNullOrEmpty(allowedPath) || redirectPath.StartsWith(allowedPath, StringComparison.OrdinalIgnoreCase)) { return true; @@ -133,6 +143,79 @@ public class OidcProviderService( return false; } + private string GenerateIdToken( + CustomApp client, + AuthSession session, + string? nonce = null, + IEnumerable? scopes = null + ) + { + var tokenHandler = new JwtSecurityTokenHandler(); + var clock = SystemClock.Instance; + var now = clock.GetCurrentInstant(); + + var claims = new List + { + new Claim(JwtRegisteredClaimNames.Iss, _options.IssuerUri), + new Claim(JwtRegisteredClaimNames.Sub, session.AccountId.ToString()), + new Claim(JwtRegisteredClaimNames.Aud, client.Id.ToString()), + new Claim(JwtRegisteredClaimNames.Iat, now.ToUnixTimeSeconds().ToString(), ClaimValueTypes.Integer64), + new Claim(JwtRegisteredClaimNames.Exp, + now.Plus(Duration.FromSeconds(_options.AccessTokenLifetime.TotalSeconds)).ToUnixTimeSeconds() + .ToString(), ClaimValueTypes.Integer64), + new Claim(JwtRegisteredClaimNames.AuthTime, session.CreatedAt.ToUnixTimeSeconds().ToString(), + ClaimValueTypes.Integer64) + }; + + // Add nonce if provided (required for implicit and hybrid flows) + if (!string.IsNullOrEmpty(nonce)) + { + claims.Add(new Claim("nonce", nonce)); + } + + // Add email claim if email scope is requested + var scopesList = scopes?.ToList() ?? []; + if (scopesList.Contains("email")) + { + var contact = session.Account.Contacts.FirstOrDefault(c => c.Type == AccountContactType.Email); + if (contact is not null) + { + claims.Add(new Claim(JwtRegisteredClaimNames.Email, contact.Content)); + claims.Add(new Claim("email_verified", contact.VerifiedAt is not null ? "true" : "false", + ClaimValueTypes.Boolean)); + } + } + + // Add profile claims if profile scope is requested + if (scopes != null && scopesList.Contains("profile")) + { + if (!string.IsNullOrEmpty(session.Account.Name)) + claims.Add(new Claim("preferred_username", session.Account.Name)); + if (!string.IsNullOrEmpty(session.Account.Nick)) + claims.Add(new Claim("name", session.Account.Nick)); + if (!string.IsNullOrEmpty(session.Account.Profile.FirstName)) + claims.Add(new Claim("given_name", session.Account.Profile.FirstName)); + if (!string.IsNullOrEmpty(session.Account.Profile.LastName)) + claims.Add(new Claim("family_name", session.Account.Profile.LastName)); + } + + var tokenDescriptor = new SecurityTokenDescriptor + { + Subject = new ClaimsIdentity(claims), + Issuer = _options.IssuerUri, + Audience = client.Id.ToString(), + Expires = now.Plus(Duration.FromSeconds(_options.AccessTokenLifetime.TotalSeconds)).ToDateTimeUtc(), + NotBefore = now.ToDateTimeUtc(), + SigningCredentials = new SigningCredentials( + new RsaSecurityKey(_options.GetRsaPrivateKey()), + SecurityAlgorithms.RsaSha256 + ) + }; + + var token = tokenHandler.CreateToken(tokenDescriptor); + return tokenHandler.WriteToken(token); + } + public async Task GenerateTokenResponseAsync( Guid clientId, string? authorizationCode = null, @@ -148,24 +231,28 @@ public class OidcProviderService( AuthSession session; var clock = SystemClock.Instance; var now = clock.GetCurrentInstant(); - + string? nonce = null; List? scopes = null; + if (authorizationCode != null) { // Authorization code flow var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId, redirectUri, codeVerifier); - if (authCode is null) throw new InvalidOperationException("Invalid authorization code"); - var account = await db.Accounts.Where(a => a.Id == authCode.AccountId).FirstOrDefaultAsync(); - if (account is null) throw new InvalidOperationException("Account was not found"); + if (authCode == null) + throw new InvalidOperationException("Invalid authorization code"); + + // Load the session for the user + session = await FindValidSessionAsync(authCode.AccountId, clientId, withAccount: true) ?? + throw new InvalidOperationException("No valid session found for user"); - session = await auth.CreateSessionForOidcAsync(account, now, clientId); scopes = authCode.Scopes; + nonce = authCode.Nonce; } else if (sessionId.HasValue) { // Refresh token flow session = await FindSessionByIdAsync(sessionId.Value) ?? - throw new InvalidOperationException("Invalid session"); + throw new InvalidOperationException("Session not found"); // Verify the session is still valid if (session.ExpiredAt < now) @@ -179,13 +266,15 @@ public class OidcProviderService( var expiresIn = (int)_options.AccessTokenLifetime.TotalSeconds; var expiresAt = now.Plus(Duration.FromSeconds(expiresIn)); - // Generate an access token + // Generate tokens var accessToken = GenerateJwtToken(client, session, expiresAt, scopes); + var idToken = GenerateIdToken(client, session, nonce, scopes); var refreshToken = GenerateRefreshToken(session); return new TokenResponse { AccessToken = accessToken, + IdToken = idToken, ExpiresIn = expiresIn, TokenType = "Bearer", RefreshToken = refreshToken, @@ -317,12 +406,13 @@ public class OidcProviderService( Nonce = nonce, CreatedAt = now }; - + // Store the code with its metadata in the cache var cacheKey = $"auth:code:{code}"; await cache.SetAsync(cacheKey, authCodeInfo, _options.AuthorizationCodeLifetime); - logger.LogInformation("Generated authorization code for client {ClientId} and user {UserId}", clientId, session.AccountId); + logger.LogInformation("Generated authorization code for client {ClientId} and user {UserId}", clientId, + session.AccountId); return code; } diff --git a/DysonNetwork.Pass/Auth/TokenAuthService.cs b/DysonNetwork.Pass/Auth/TokenAuthService.cs index 371eb9d..c139f0e 100644 --- a/DysonNetwork.Pass/Auth/TokenAuthService.cs +++ b/DysonNetwork.Pass/Auth/TokenAuthService.cs @@ -23,7 +23,7 @@ public class TokenAuthService( /// /// Incoming token string /// (Valid, Session, Message) - public async Task<(bool Valid, AuthSession? Session, string? Message)> AuthenticateTokenAsync(string token) + public async Task<(bool Valid, AuthSession? Session, string? Message)> AuthenticateTokenAsync(string token, string? ipAddress = null) { try { @@ -32,6 +32,11 @@ public class TokenAuthService( logger.LogWarning("AuthenticateTokenAsync: no token provided"); return (false, null, "No token provided."); } + + if (!string.IsNullOrEmpty(ipAddress)) + { + logger.LogDebug("AuthenticateTokenAsync: client IP: {IpAddress}", ipAddress); + } // token fingerprint for correlation var tokenHash = Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(token))); @@ -70,7 +75,7 @@ public class TokenAuthService( "AuthenticateTokenAsync: success via cache (sessionId={SessionId}, accountId={AccountId}, scopes={ScopeCount}, expiresAt={ExpiresAt})", sessionId, session.AccountId, - session.Challenge.Scopes.Count, + session.Challenge?.Scopes.Count, session.ExpiredAt ); return (true, session, null); @@ -103,11 +108,11 @@ public class TokenAuthService( "AuthenticateTokenAsync: DB session loaded (sessionId={SessionId}, accountId={AccountId}, clientId={ClientId}, appId={AppId}, scopes={ScopeCount}, ip={Ip}, uaLen={UaLen})", sessionId, session.AccountId, - session.Challenge.ClientId, + session.Challenge?.ClientId, session.AppId, - session.Challenge.Scopes.Count, - session.Challenge.IpAddress, - (session.Challenge.UserAgent ?? string.Empty).Length + session.Challenge?.Scopes.Count, + session.Challenge?.IpAddress, + (session.Challenge?.UserAgent ?? string.Empty).Length ); logger.LogDebug("AuthenticateTokenAsync: enriching account with subscription (accountId={AccountId})", session.AccountId); @@ -136,7 +141,7 @@ public class TokenAuthService( "AuthenticateTokenAsync: success via DB (sessionId={SessionId}, accountId={AccountId}, clientId={ClientId})", sessionId, session.AccountId, - session.Challenge.ClientId + session.Challenge?.ClientId ); return (true, session, null); }