From 5a0c6dc4b039441dcde768c528e9b24d25cad104 Mon Sep 17 00:00:00 2001 From: LittleSheep Date: Sat, 7 Jun 2025 18:21:51 +0800 Subject: [PATCH] :sparkles: Optimized risk detection :bug: Fix bugs --- .../Account/AccountCurrentController.cs | 38 ++++++- DysonNetwork.Sphere/Account/AccountService.cs | 20 +++- DysonNetwork.Sphere/Auth/AuthController.cs | 5 +- DysonNetwork.Sphere/Auth/AuthService.cs | 98 +++++++++++++++---- 4 files changed, 137 insertions(+), 24 deletions(-) diff --git a/DysonNetwork.Sphere/Account/AccountCurrentController.cs b/DysonNetwork.Sphere/Account/AccountCurrentController.cs index eb0eac3..f663492 100644 --- a/DysonNetwork.Sphere/Account/AccountCurrentController.cs +++ b/DysonNetwork.Sphere/Account/AccountCurrentController.cs @@ -428,7 +428,8 @@ public class AccountCurrentController( [FromQuery] int offset = 0 ) { - if (HttpContext.Items["CurrentUser"] is not Account currentUser) return Unauthorized(); + if (HttpContext.Items["CurrentUser"] is not Account currentUser || + HttpContext.Items["CurrentSession"] is not Session currentSession) return Unauthorized(); var query = db.AuthSessions .Include(session => session.Account) @@ -438,8 +439,10 @@ public class AccountCurrentController( var total = await query.CountAsync(); Response.Headers.Append("X-Total", total.ToString()); + Response.Headers.Append("X-Auth-Session", currentSession.Id.ToString()); var sessions = await query + .OrderByDescending(x => x.LastGrantedAt) .Skip(offset) .Take(take) .ToListAsync(); @@ -481,4 +484,37 @@ public class AccountCurrentController( return BadRequest(ex.Message); } } + + [HttpPatch("sessions/{id:guid}/label")] + public async Task> UpdateSessionLabel(Guid id, [FromBody] string label) + { + if (HttpContext.Items["CurrentUser"] is not Account currentUser) return Unauthorized(); + + try + { + await accounts.UpdateSessionLabel(currentUser, id, label); + return NoContent(); + } + catch (Exception ex) + { + return BadRequest(ex.Message); + } + } + + [HttpPatch("sessions/current/label")] + public async Task> UpdateCurrentSessionLabel([FromBody] string label) + { + if (HttpContext.Items["CurrentUser"] is not Account currentUser || + HttpContext.Items["CurrentSession"] is not Session currentSession) return Unauthorized(); + + try + { + await accounts.UpdateSessionLabel(currentUser, currentSession.Id, label); + return NoContent(); + } + catch (Exception ex) + { + return BadRequest(ex.Message); + } + } } \ No newline at end of file diff --git a/DysonNetwork.Sphere/Account/AccountService.cs b/DysonNetwork.Sphere/Account/AccountService.cs index abd3d8b..77c0a1c 100644 --- a/DysonNetwork.Sphere/Account/AccountService.cs +++ b/DysonNetwork.Sphere/Account/AccountService.cs @@ -293,7 +293,9 @@ public class AccountService( case AccountAuthFactorType.EmailCode: case AccountAuthFactorType.InAppCode: var correctCode = await _GetFactorCode(factor); - return correctCode is not null && string.Equals(correctCode, code, StringComparison.OrdinalIgnoreCase); + var isCorrect = correctCode is not null && string.Equals(correctCode, code, StringComparison.OrdinalIgnoreCase); + await cache.RemoveAsync($"{AuthFactorCachePrefix}{factor.Id}:code"); + return isCorrect; case AccountAuthFactorType.Password: case AccountAuthFactorType.TimedCode: default: @@ -318,6 +320,22 @@ public class AccountService( $"{AuthFactorCachePrefix}{factor.Id}:code" ); } + + public async Task UpdateSessionLabel(Account account, Guid sessionId, string label) + { + var session = await db.AuthSessions + .Include(s => s.Challenge) + .Where(s => s.Id == sessionId && s.AccountId == account.Id) + .FirstOrDefaultAsync(); + if (session is null) throw new InvalidOperationException("Session was not found."); + + session.Label = label; + await db.SaveChangesAsync(); + + await cache.RemoveAsync($"{DysonTokenAuthHandler.AuthCachePrefix}{session.Id}"); + + return session; + } public async Task DeleteSession(Account account, Guid sessionId) { diff --git a/DysonNetwork.Sphere/Auth/AuthController.cs b/DysonNetwork.Sphere/Auth/AuthController.cs index 0f1702e..e252fd2 100644 --- a/DysonNetwork.Sphere/Auth/AuthController.cs +++ b/DysonNetwork.Sphere/Auth/AuthController.cs @@ -53,7 +53,7 @@ public class AuthController( var challenge = new Challenge { ExpiredAt = Instant.FromDateTimeUtc(DateTime.UtcNow.AddHours(1)), - StepTotal = 3, + StepTotal = await auth.DetectChallengeRisk(Request, account), Platform = request.Platform, Audiences = request.Audiences, Scopes = request.Scopes, @@ -205,7 +205,6 @@ public class AuthController( [HttpPost("token")] public async Task> ExchangeToken([FromBody] TokenExchangeRequest request) { - Session? session; switch (request.GrantType) { case "authorization_code": @@ -221,7 +220,7 @@ public class AuthController( if (challenge.StepRemain != 0) return BadRequest("Challenge not yet completed."); - session = await db.AuthSessions + var session = await db.AuthSessions .Where(e => e.Challenge == challenge) .FirstOrDefaultAsync(); if (session is not null) diff --git a/DysonNetwork.Sphere/Auth/AuthService.cs b/DysonNetwork.Sphere/Auth/AuthService.cs index 3d5eb0d..e60eeb3 100644 --- a/DysonNetwork.Sphere/Auth/AuthService.cs +++ b/DysonNetwork.Sphere/Auth/AuthService.cs @@ -1,19 +1,79 @@ using System.Security.Cryptography; using System.Text.Json; +using Microsoft.EntityFrameworkCore; +using NodaTime; namespace DysonNetwork.Sphere.Auth; -public class AuthService(IConfiguration config, IHttpClientFactory httpClientFactory) +public class AuthService(AppDatabase db, IConfiguration config, IHttpClientFactory httpClientFactory) { + /// + /// Detect the risk of the current request to login + /// and returns the required steps to login. + /// + /// The request context + /// The account to login + /// The required steps to login + public async Task DetectChallengeRisk(HttpRequest request, Account.Account account) + { + // 1) Find out how many authentication factors the account has enabled. + var maxSteps = await db.AccountAuthFactors + .Where(f => f.AccountId == account.Id) + .Where(f => f.EnabledAt != null) + .CountAsync(); + + // We’ll accumulate a “risk score” based on various factors. + // Then we can decide how many total steps are required for the challenge. + var riskScore = 0; + + // 2) Get the remote IP address from the request (if any). + var ipAddress = request.HttpContext.Connection.RemoteIpAddress?.ToString(); + var lastActiveInfo = await db.AuthSessions + .OrderByDescending(s => s.LastGrantedAt) + .Include(s => s.Challenge) + .Where(s => s.AccountId == account.Id) + .FirstOrDefaultAsync(); + + // Example check: if IP is missing or in an unusual range, increase the risk. + // (This is just a placeholder; in reality, you’d integrate with GeoIpService or a custom check.) + if (string.IsNullOrWhiteSpace(ipAddress)) + riskScore += 1; + else + { + if (!string.IsNullOrEmpty(lastActiveInfo?.Challenge.IpAddress) && + !lastActiveInfo.Challenge.IpAddress.Equals(ipAddress, StringComparison.OrdinalIgnoreCase)) + riskScore += 1; + } + + // 3) (Optional) Check how recent the last login was. + // If it was a long time ago, the risk might be higher. + var now = SystemClock.Instance.GetCurrentInstant(); + var daysSinceLastActive = lastActiveInfo?.LastGrantedAt is not null + ? (now - lastActiveInfo.LastGrantedAt.Value).TotalDays + : double.MaxValue; + if (daysSinceLastActive > 30) + riskScore += 1; + + // 4) Combine base “maxSteps” (the number of enabled factors) with any accumulated risk score. + // You might choose to make “maxSteps + riskScore” your total required steps, + // or clamp it to maxSteps if you only want to require existing available factors. + var totalRequiredSteps = maxSteps + riskScore; + + // Clamp the step + totalRequiredSteps = Math.Max(Math.Min(totalRequiredSteps, maxSteps), 1); + + return totalRequiredSteps; + } + public async Task ValidateCaptcha(string token) { if (string.IsNullOrWhiteSpace(token)) return false; - + var provider = config.GetSection("Captcha")["Provider"]?.ToLower(); var apiSecret = config.GetSection("Captcha")["ApiSecret"]; var client = httpClientFactory.CreateClient(); - + var jsonOpts = new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, @@ -28,7 +88,7 @@ public class AuthService(IConfiguration config, IHttpClientFactory httpClientFac var response = await client.PostAsync("https://challenges.cloudflare.com/turnstile/v0/siteverify", content); response.EnsureSuccessStatusCode(); - + var json = await response.Content.ReadAsStringAsync(); var result = JsonSerializer.Deserialize(json, options: jsonOpts); @@ -64,7 +124,7 @@ public class AuthService(IConfiguration config, IHttpClientFactory httpClientFac var privateKeyPem = File.ReadAllText(config["Jwt:PrivateKeyPath"]!); using var rsa = RSA.Create(); rsa.ImportFromPem(privateKeyPem); - + // Create and return a single token return CreateCompactToken(session.Id, rsa); } @@ -73,42 +133,42 @@ public class AuthService(IConfiguration config, IHttpClientFactory httpClientFac { // Create the payload: just the session ID var payloadBytes = sessionId.ToByteArray(); - + // Base64Url encode the payload var payloadBase64 = Base64UrlEncode(payloadBytes); - + // Sign the payload with RSA-SHA256 var signature = rsa.SignData(payloadBytes, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); - + // Base64Url encode the signature var signatureBase64 = Base64UrlEncode(signature); - + // Combine payload and signature with a dot return $"{payloadBase64}.{signatureBase64}"; } - + public bool ValidateToken(string token, out Guid sessionId) { sessionId = Guid.Empty; - + try { // Split the token var parts = token.Split('.'); if (parts.Length != 2) return false; - + // Decode the payload var payloadBytes = Base64UrlDecode(parts[0]); - + // Extract session ID sessionId = new Guid(payloadBytes); - + // Load public key for verification var publicKeyPem = File.ReadAllText(config["Jwt:PublicKeyPath"]!); using var rsa = RSA.Create(); rsa.ImportFromPem(publicKeyPem); - + // Verify signature var signature = Base64UrlDecode(parts[1]); return rsa.VerifyData(payloadBytes, signature, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); @@ -118,7 +178,7 @@ public class AuthService(IConfiguration config, IHttpClientFactory httpClientFac return false; } } - + // Helper methods for Base64Url encoding/decoding private static string Base64UrlEncode(byte[] data) { @@ -127,19 +187,19 @@ public class AuthService(IConfiguration config, IHttpClientFactory httpClientFac .Replace('+', '-') .Replace('/', '_'); } - + private static byte[] Base64UrlDecode(string base64Url) { string padded = base64Url .Replace('-', '+') .Replace('_', '/'); - + switch (padded.Length % 4) { case 2: padded += "=="; break; case 3: padded += "="; break; } - + return Convert.FromBase64String(padded); } } \ No newline at end of file