♻️ Refactored auth service for better security
This commit is contained in:
@@ -30,7 +30,7 @@ public class AuthController(
|
||||
|
||||
public class ChallengeRequest
|
||||
{
|
||||
[Required] public ClientPlatform Platform { get; set; }
|
||||
[Required] public Shared.Models.ClientPlatform Platform { get; set; }
|
||||
[Required] [MaxLength(256)] public string Account { get; set; } = null!;
|
||||
[Required] [MaxLength(512)] public string DeviceId { get; set; } = null!;
|
||||
[MaxLength(1024)] public string? DeviceName { get; set; }
|
||||
@@ -61,9 +61,6 @@ public class AuthController(
|
||||
|
||||
request.DeviceName ??= userAgent;
|
||||
|
||||
var device =
|
||||
await auth.GetOrCreateDeviceAsync(account.Id, request.DeviceId, request.DeviceName, request.Platform);
|
||||
|
||||
// Trying to pick up challenges from the same IP address and user agent
|
||||
var existingChallenge = await db.AuthChallenges
|
||||
.Where(e => e.AccountId == account.Id)
|
||||
@@ -72,7 +69,7 @@ public class AuthController(
|
||||
.Where(e => e.StepRemain > 0)
|
||||
.Where(e => e.ExpiredAt != null && now < e.ExpiredAt)
|
||||
.Where(e => e.Type == Shared.Models.ChallengeType.Login)
|
||||
.Where(e => e.ClientId == device.Id)
|
||||
.Where(e => e.DeviceId == request.DeviceId)
|
||||
.FirstOrDefaultAsync();
|
||||
if (existingChallenge is not null)
|
||||
{
|
||||
@@ -90,7 +87,9 @@ public class AuthController(
|
||||
IpAddress = ipAddress,
|
||||
UserAgent = userAgent,
|
||||
Location = geo.GetPointFromIp(ipAddress),
|
||||
ClientId = device.Id,
|
||||
DeviceId = request.DeviceId,
|
||||
DeviceName = request.DeviceName,
|
||||
Platform = request.Platform,
|
||||
AccountId = account.Id
|
||||
}.Normalize();
|
||||
|
||||
@@ -176,7 +175,6 @@ public class AuthController(
|
||||
{
|
||||
var challenge = await db.AuthChallenges
|
||||
.Include(e => e.Account)
|
||||
.Include(authChallenge => authChallenge.Client)
|
||||
.FirstOrDefaultAsync(e => e.Id == id);
|
||||
if (challenge is null) return NotFound("Auth challenge was not found.");
|
||||
|
||||
@@ -246,7 +244,7 @@ public class AuthController(
|
||||
{
|
||||
Topic = "auth.login",
|
||||
Title = localizer["NewLoginTitle"],
|
||||
Body = localizer["NewLoginBody", challenge.Client?.DeviceName ?? "unknown",
|
||||
Body = localizer["NewLoginBody", challenge.DeviceName ?? "unknown",
|
||||
challenge.IpAddress ?? "unknown"],
|
||||
IsSavable = true
|
||||
},
|
||||
|
||||
@@ -157,7 +157,7 @@ public class AuthService(
|
||||
// 8) Device Trust Assessment
|
||||
var trustedDeviceIds = recentSessions
|
||||
.Where(s => s.CreatedAt > now.Minus(Duration.FromDays(30))) // Trust devices from last 30 days
|
||||
.Select(s => s.Challenge?.ClientId)
|
||||
.Select(s => s.ClientId)
|
||||
.Where(id => id.HasValue)
|
||||
.Distinct()
|
||||
.ToList();
|
||||
@@ -182,7 +182,7 @@ public class AuthService(
|
||||
}
|
||||
|
||||
public async Task<SnAuthSession> CreateSessionForOidcAsync(SnAccount account, Instant time,
|
||||
Guid? customAppId = null)
|
||||
Guid? customAppId = null, SnAuthSession? parentSession = null)
|
||||
{
|
||||
var challenge = new SnAuthChallenge
|
||||
{
|
||||
@@ -191,7 +191,10 @@ public class AuthService(
|
||||
UserAgent = HttpContext.Request.Headers.UserAgent,
|
||||
StepRemain = 1,
|
||||
StepTotal = 1,
|
||||
Type = customAppId is not null ? ChallengeType.OAuth : ChallengeType.Oidc
|
||||
Type = customAppId is not null ? ChallengeType.OAuth : ChallengeType.Oidc,
|
||||
DeviceId = Guid.NewGuid().ToString(),
|
||||
DeviceName = "OIDC/OAuth",
|
||||
Platform = ClientPlatform.Web,
|
||||
};
|
||||
|
||||
var session = new SnAuthSession
|
||||
@@ -200,7 +203,8 @@ public class AuthService(
|
||||
CreatedAt = time,
|
||||
LastGrantedAt = time,
|
||||
Challenge = challenge,
|
||||
AppId = customAppId
|
||||
AppId = customAppId,
|
||||
ParentSessionId = parentSession?.Id
|
||||
};
|
||||
|
||||
db.AuthChallenges.Add(challenge);
|
||||
@@ -288,35 +292,75 @@ public class AuthService(
|
||||
|
||||
/// <summary>
|
||||
/// Immediately revoke a session by setting expiry to now and clearing from cache
|
||||
/// This provides immediate invalidation of tokens and sessions
|
||||
/// This provides immediate invalidation of tokens and sessions, including all child sessions recursively.
|
||||
/// </summary>
|
||||
/// <param name="sessionId">Session ID to revoke</param>
|
||||
/// <returns>True if session was found and revoked, false otherwise</returns>
|
||||
public async Task<bool> RevokeSessionAsync(Guid sessionId)
|
||||
{
|
||||
var session = await db.AuthSessions.FirstOrDefaultAsync(s => s.Id == sessionId);
|
||||
if (session == null)
|
||||
var sessionsToRevokeIds = new HashSet<Guid>();
|
||||
await CollectSessionsToRevoke(sessionId, sessionsToRevokeIds);
|
||||
|
||||
if (sessionsToRevokeIds.Count == 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Set expiry to now (immediate invalidation)
|
||||
var now = SystemClock.Instance.GetCurrentInstant();
|
||||
session.ExpiredAt = now;
|
||||
db.AuthSessions.Update(session);
|
||||
var accountIdsToClearCache = new HashSet<Guid>();
|
||||
|
||||
// Clear from cache immediately
|
||||
var cacheKey = $"{AuthCachePrefix}{session.Id}";
|
||||
await cache.RemoveAsync(cacheKey);
|
||||
// Fetch all sessions to be revoked in one go
|
||||
var sessions = await db.AuthSessions
|
||||
.Where(s => sessionsToRevokeIds.Contains(s.Id))
|
||||
.ToListAsync();
|
||||
|
||||
// Clear account-level cache groups that include this session
|
||||
await cache.RemoveAsync($"{AuthCachePrefix}{session.AccountId}");
|
||||
foreach (var session in sessions)
|
||||
{
|
||||
session.ExpiredAt = now;
|
||||
accountIdsToClearCache.Add(session.AccountId);
|
||||
|
||||
// Clear from cache immediately for each session
|
||||
await cache.RemoveAsync($"{AuthCachePrefix}{session.Id}");
|
||||
}
|
||||
|
||||
db.AuthSessions.UpdateRange(sessions);
|
||||
await db.SaveChangesAsync();
|
||||
|
||||
// Clear account-level cache groups
|
||||
foreach (var accountId in accountIdsToClearCache)
|
||||
{
|
||||
await cache.RemoveAsync($"{AuthCachePrefix}{accountId}");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Recursively collects all session IDs that need to be revoked, starting from a given session.
|
||||
/// </summary>
|
||||
/// <param name="currentSessionId">The session ID to start collecting from.</param>
|
||||
/// <param name="sessionsToRevoke">A HashSet to store the IDs of all sessions to be revoked.</param>
|
||||
private async Task CollectSessionsToRevoke(Guid currentSessionId, HashSet<Guid> sessionsToRevoke)
|
||||
{
|
||||
if (sessionsToRevoke.Contains(currentSessionId))
|
||||
{
|
||||
return; // Already processed this session
|
||||
}
|
||||
|
||||
sessionsToRevoke.Add(currentSessionId);
|
||||
|
||||
// Find direct children
|
||||
var childSessions = await db.AuthSessions
|
||||
.Where(s => s.ParentSessionId == currentSessionId)
|
||||
.Select(s => s.Id)
|
||||
.ToListAsync();
|
||||
|
||||
foreach (var childId in childSessions)
|
||||
{
|
||||
await CollectSessionsToRevoke(childId, sessionsToRevoke);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Revoke all sessions for an account (logout everywhere)
|
||||
/// </summary>
|
||||
@@ -380,13 +424,17 @@ public class AuthService(
|
||||
if (hasSession)
|
||||
throw new ArgumentException("Session already exists for this challenge.");
|
||||
|
||||
var device = await GetOrCreateDeviceAsync(challenge.AccountId, challenge.DeviceId, challenge.DeviceName,
|
||||
challenge.Platform);
|
||||
|
||||
var now = SystemClock.Instance.GetCurrentInstant();
|
||||
var session = new SnAuthSession
|
||||
{
|
||||
LastGrantedAt = now,
|
||||
ExpiredAt = now.Plus(Duration.FromDays(7)),
|
||||
AccountId = challenge.AccountId,
|
||||
ChallengeId = challenge.Id
|
||||
ChallengeId = challenge.Id,
|
||||
ClientId = device.Id,
|
||||
};
|
||||
|
||||
db.AuthSessions.Add(session);
|
||||
@@ -500,7 +548,7 @@ public class AuthService(
|
||||
return key;
|
||||
}
|
||||
|
||||
public async Task<SnApiKey> CreateApiKey(Guid accountId, string label, Instant? expiredAt = null)
|
||||
public async Task<SnApiKey> CreateApiKey(Guid accountId, string label, Instant? expiredAt = null, SnAuthSession? parentSession = null)
|
||||
{
|
||||
var key = new SnApiKey
|
||||
{
|
||||
@@ -509,7 +557,8 @@ public class AuthService(
|
||||
Session = new SnAuthSession
|
||||
{
|
||||
AccountId = accountId,
|
||||
ExpiredAt = expiredAt
|
||||
ExpiredAt = expiredAt,
|
||||
ParentSessionId = parentSession?.Id
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -342,13 +342,19 @@ public class ConnectionController(
|
||||
callbackData.State.Split('|').FirstOrDefault() :
|
||||
string.Empty;
|
||||
|
||||
var challenge = await oidcService.CreateChallengeForUserAsync(
|
||||
if (HttpContext.Items["CurrentSession"] is not SnAuthSession parentSession) parentSession = null;
|
||||
|
||||
var session = await oidcService.CreateSessionForUserAsync(
|
||||
userInfo,
|
||||
connection.Account,
|
||||
HttpContext,
|
||||
deviceId ?? string.Empty);
|
||||
deviceId ?? string.Empty,
|
||||
null,
|
||||
ClientPlatform.Web,
|
||||
parentSession);
|
||||
|
||||
var redirectUrl = QueryHelpers.AddQueryString(redirectBaseUrl, "challenge", challenge.Id.ToString());
|
||||
var token = auth.CreateToken(session);
|
||||
var redirectUrl = QueryHelpers.AddQueryString(redirectBaseUrl, "token", token);
|
||||
logger.LogInformation("OIDC login successful for user {UserId}. Redirecting to {RedirectUrl}", connection.AccountId, redirectUrl);
|
||||
return Redirect(redirectUrl);
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ public class OidcController(
|
||||
IServiceProvider serviceProvider,
|
||||
AppDatabase db,
|
||||
AccountService accounts,
|
||||
AuthService auth,
|
||||
ICacheService cache,
|
||||
ILogger<OidcController> logger
|
||||
)
|
||||
@@ -22,6 +23,11 @@ public class OidcController(
|
||||
private const string StateCachePrefix = "oidc-state:";
|
||||
private static readonly TimeSpan StateExpiration = TimeSpan.FromMinutes(15);
|
||||
|
||||
public class TokenExchangeResponse
|
||||
{
|
||||
public string Token { get; set; } = string.Empty;
|
||||
}
|
||||
|
||||
[HttpGet("{provider}")]
|
||||
public async Task<ActionResult> OidcLogin(
|
||||
[FromRoute] string provider,
|
||||
@@ -75,7 +81,7 @@ public class OidcController(
|
||||
/// Handles Apple authentication directly from mobile apps
|
||||
/// </summary>
|
||||
[HttpPost("apple/mobile")]
|
||||
public async Task<ActionResult<SnAuthChallenge>> AppleMobileLogin(
|
||||
public async Task<ActionResult<TokenExchangeResponse>> AppleMobileLogin(
|
||||
[FromBody] AppleMobileSignInRequest request
|
||||
)
|
||||
{
|
||||
@@ -98,16 +104,21 @@ public class OidcController(
|
||||
// Find or create user account using existing logic
|
||||
var account = await FindOrCreateAccount(userInfo, "apple");
|
||||
|
||||
if (HttpContext.Items["CurrentSession"] is not SnAuthSession parentSession) parentSession = null;
|
||||
|
||||
// Create session using the OIDC service
|
||||
var challenge = await appleService.CreateChallengeForUserAsync(
|
||||
var session = await appleService.CreateSessionForUserAsync(
|
||||
userInfo,
|
||||
account,
|
||||
HttpContext,
|
||||
request.DeviceId,
|
||||
request.DeviceName
|
||||
request.DeviceName,
|
||||
ClientPlatform.Ios,
|
||||
parentSession
|
||||
);
|
||||
|
||||
return Ok(challenge);
|
||||
|
||||
var token = auth.CreateToken(session);
|
||||
return Ok(new TokenExchangeResponse { Token = token });
|
||||
}
|
||||
catch (SecurityTokenValidationException ex)
|
||||
{
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
using System;
|
||||
using System.IdentityModel.Tokens.Jwt;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
@@ -250,15 +249,17 @@ public abstract class OidcService(
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a challenge and session for an authenticated user
|
||||
/// Creates a session for an authenticated user
|
||||
/// Also creates or updates the account connection
|
||||
/// </summary>
|
||||
public async Task<SnAuthChallenge> CreateChallengeForUserAsync(
|
||||
public async Task<SnAuthSession> CreateSessionForUserAsync(
|
||||
OidcUserInfo userInfo,
|
||||
SnAccount account,
|
||||
HttpContext request,
|
||||
string deviceId,
|
||||
string? deviceName = null
|
||||
string? deviceName = null,
|
||||
ClientPlatform platform = ClientPlatform.Web,
|
||||
SnAuthSession? parentSession = null
|
||||
)
|
||||
{
|
||||
// Create or update the account connection
|
||||
@@ -282,28 +283,24 @@ public abstract class OidcService(
|
||||
await Db.AccountConnections.AddAsync(connection);
|
||||
}
|
||||
|
||||
// Create a challenge that's already completed
|
||||
// Create a session directly
|
||||
var now = SystemClock.Instance.GetCurrentInstant();
|
||||
var device = await auth.GetOrCreateDeviceAsync(account.Id, deviceId, deviceName, ClientPlatform.Ios);
|
||||
var challenge = new SnAuthChallenge
|
||||
{
|
||||
ExpiredAt = now.Plus(Duration.FromHours(1)),
|
||||
StepTotal = await auth.DetectChallengeRisk(request.Request, account),
|
||||
Type = ChallengeType.Oidc,
|
||||
Audiences = [ProviderName],
|
||||
Scopes = ["*"],
|
||||
AccountId = account.Id,
|
||||
ClientId = device.Id,
|
||||
IpAddress = request.Connection.RemoteIpAddress?.ToString() ?? null,
|
||||
UserAgent = request.Request.Headers.UserAgent,
|
||||
};
|
||||
challenge.StepRemain--;
|
||||
if (challenge.StepRemain < 0) challenge.StepRemain = 0;
|
||||
var device = await auth.GetOrCreateDeviceAsync(account.Id, deviceId, deviceName, platform);
|
||||
|
||||
await Db.AuthChallenges.AddAsync(challenge);
|
||||
var session = new SnAuthSession
|
||||
{
|
||||
AccountId = account.Id,
|
||||
CreatedAt = now,
|
||||
LastGrantedAt = now,
|
||||
ParentSessionId = parentSession?.Id,
|
||||
ClientId = device.Id,
|
||||
ExpiredAt = now.Plus(Duration.FromDays(30))
|
||||
};
|
||||
|
||||
await Db.AuthSessions.AddAsync(session);
|
||||
await Db.SaveChangesAsync();
|
||||
|
||||
return challenge;
|
||||
return session;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ public class TokenAuthService(
|
||||
session = await db.AuthSessions
|
||||
.AsNoTracking()
|
||||
.Include(e => e.Challenge)
|
||||
.ThenInclude(e => e.Client)
|
||||
.Include(e => e.Client)
|
||||
.Include(e => e.Account)
|
||||
.ThenInclude(e => e.Profile)
|
||||
.FirstOrDefaultAsync(s => s.Id == sessionId);
|
||||
@@ -110,7 +110,7 @@ public class TokenAuthService(
|
||||
"AuthenticateTokenAsync: DB session loaded (sessionId={SessionId}, accountId={AccountId}, clientId={ClientId}, appId={AppId}, scopes={ScopeCount}, ip={Ip}, uaLen={UaLen})",
|
||||
sessionId,
|
||||
session.AccountId,
|
||||
session.Challenge?.ClientId,
|
||||
session.ClientId,
|
||||
session.AppId,
|
||||
session.Challenge?.Scopes.Count,
|
||||
session.Challenge?.IpAddress,
|
||||
@@ -143,7 +143,7 @@ public class TokenAuthService(
|
||||
"AuthenticateTokenAsync: success via DB (sessionId={SessionId}, accountId={AccountId}, clientId={ClientId})",
|
||||
sessionId,
|
||||
session.AccountId,
|
||||
session.Challenge?.ClientId
|
||||
session.ClientId
|
||||
);
|
||||
return (true, session, null);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user