✨ Optimized risk detection
🐛 Fix bugs
			
			
This commit is contained in:
		| @@ -53,7 +53,7 @@ public class AuthController( | ||||
|         var challenge = new Challenge | ||||
|         { | ||||
|             ExpiredAt = Instant.FromDateTimeUtc(DateTime.UtcNow.AddHours(1)), | ||||
|             StepTotal = 3, | ||||
|             StepTotal = await auth.DetectChallengeRisk(Request, account), | ||||
|             Platform = request.Platform, | ||||
|             Audiences = request.Audiences, | ||||
|             Scopes = request.Scopes, | ||||
| @@ -205,7 +205,6 @@ public class AuthController( | ||||
|     [HttpPost("token")] | ||||
|     public async Task<ActionResult<TokenExchangeResponse>> ExchangeToken([FromBody] TokenExchangeRequest request) | ||||
|     { | ||||
|         Session? session; | ||||
|         switch (request.GrantType) | ||||
|         { | ||||
|             case "authorization_code": | ||||
| @@ -221,7 +220,7 @@ public class AuthController( | ||||
|                 if (challenge.StepRemain != 0) | ||||
|                     return BadRequest("Challenge not yet completed."); | ||||
|  | ||||
|                 session = await db.AuthSessions | ||||
|                 var session = await db.AuthSessions | ||||
|                     .Where(e => e.Challenge == challenge) | ||||
|                     .FirstOrDefaultAsync(); | ||||
|                 if (session is not null) | ||||
|   | ||||
| @@ -1,19 +1,79 @@ | ||||
| using System.Security.Cryptography; | ||||
| using System.Text.Json; | ||||
| using Microsoft.EntityFrameworkCore; | ||||
| using NodaTime; | ||||
|  | ||||
| namespace DysonNetwork.Sphere.Auth; | ||||
|  | ||||
| public class AuthService(IConfiguration config, IHttpClientFactory httpClientFactory) | ||||
| public class AuthService(AppDatabase db, IConfiguration config, IHttpClientFactory httpClientFactory) | ||||
| { | ||||
|     /// <summary> | ||||
|     /// Detect the risk of the current request to login | ||||
|     /// and returns the required steps to login. | ||||
|     /// </summary> | ||||
|     /// <param name="request">The request context</param> | ||||
|     /// <param name="account">The account to login</param> | ||||
|     /// <returns>The required steps to login</returns> | ||||
|     public async Task<int> DetectChallengeRisk(HttpRequest request, Account.Account account) | ||||
|     { | ||||
|         // 1) Find out how many authentication factors the account has enabled. | ||||
|         var maxSteps = await db.AccountAuthFactors | ||||
|             .Where(f => f.AccountId == account.Id) | ||||
|             .Where(f => f.EnabledAt != null) | ||||
|             .CountAsync(); | ||||
|  | ||||
|         // We’ll accumulate a “risk score” based on various factors. | ||||
|         // Then we can decide how many total steps are required for the challenge. | ||||
|         var riskScore = 0; | ||||
|  | ||||
|         // 2) Get the remote IP address from the request (if any). | ||||
|         var ipAddress = request.HttpContext.Connection.RemoteIpAddress?.ToString(); | ||||
|         var lastActiveInfo = await db.AuthSessions | ||||
|             .OrderByDescending(s => s.LastGrantedAt) | ||||
|             .Include(s => s.Challenge) | ||||
|             .Where(s => s.AccountId == account.Id) | ||||
|             .FirstOrDefaultAsync(); | ||||
|  | ||||
|         // Example check: if IP is missing or in an unusual range, increase the risk. | ||||
|         // (This is just a placeholder; in reality, you’d integrate with GeoIpService or a custom check.) | ||||
|         if (string.IsNullOrWhiteSpace(ipAddress)) | ||||
|             riskScore += 1; | ||||
|         else | ||||
|         { | ||||
|             if (!string.IsNullOrEmpty(lastActiveInfo?.Challenge.IpAddress) && | ||||
|                 !lastActiveInfo.Challenge.IpAddress.Equals(ipAddress, StringComparison.OrdinalIgnoreCase)) | ||||
|                 riskScore += 1; | ||||
|         } | ||||
|  | ||||
|         // 3) (Optional) Check how recent the last login was. | ||||
|         // If it was a long time ago, the risk might be higher. | ||||
|         var now = SystemClock.Instance.GetCurrentInstant(); | ||||
|         var daysSinceLastActive = lastActiveInfo?.LastGrantedAt is not null | ||||
|             ? (now - lastActiveInfo.LastGrantedAt.Value).TotalDays | ||||
|             : double.MaxValue; | ||||
|         if (daysSinceLastActive > 30) | ||||
|             riskScore += 1; | ||||
|  | ||||
|         // 4) Combine base “maxSteps” (the number of enabled factors) with any accumulated risk score. | ||||
|         // You might choose to make “maxSteps + riskScore” your total required steps, | ||||
|         // or clamp it to maxSteps if you only want to require existing available factors. | ||||
|         var totalRequiredSteps = maxSteps + riskScore; | ||||
|  | ||||
|         // Clamp the step | ||||
|         totalRequiredSteps = Math.Max(Math.Min(totalRequiredSteps, maxSteps), 1); | ||||
|  | ||||
|         return totalRequiredSteps; | ||||
|     } | ||||
|  | ||||
|     public async Task<bool> ValidateCaptcha(string token) | ||||
|     { | ||||
|         if (string.IsNullOrWhiteSpace(token)) return false; | ||||
|          | ||||
|  | ||||
|         var provider = config.GetSection("Captcha")["Provider"]?.ToLower(); | ||||
|         var apiSecret = config.GetSection("Captcha")["ApiSecret"]; | ||||
|  | ||||
|         var client = httpClientFactory.CreateClient(); | ||||
|          | ||||
|  | ||||
|         var jsonOpts = new JsonSerializerOptions | ||||
|         { | ||||
|             PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, | ||||
| @@ -28,7 +88,7 @@ public class AuthService(IConfiguration config, IHttpClientFactory httpClientFac | ||||
|                 var response = await client.PostAsync("https://challenges.cloudflare.com/turnstile/v0/siteverify", | ||||
|                     content); | ||||
|                 response.EnsureSuccessStatusCode(); | ||||
|                  | ||||
|  | ||||
|                 var json = await response.Content.ReadAsStringAsync(); | ||||
|                 var result = JsonSerializer.Deserialize<CaptchaVerificationResponse>(json, options: jsonOpts); | ||||
|  | ||||
| @@ -64,7 +124,7 @@ public class AuthService(IConfiguration config, IHttpClientFactory httpClientFac | ||||
|         var privateKeyPem = File.ReadAllText(config["Jwt:PrivateKeyPath"]!); | ||||
|         using var rsa = RSA.Create(); | ||||
|         rsa.ImportFromPem(privateKeyPem); | ||||
|          | ||||
|  | ||||
|         // Create and return a single token | ||||
|         return CreateCompactToken(session.Id, rsa); | ||||
|     } | ||||
| @@ -73,42 +133,42 @@ public class AuthService(IConfiguration config, IHttpClientFactory httpClientFac | ||||
|     { | ||||
|         // Create the payload: just the session ID | ||||
|         var payloadBytes = sessionId.ToByteArray(); | ||||
|          | ||||
|  | ||||
|         // Base64Url encode the payload | ||||
|         var payloadBase64 = Base64UrlEncode(payloadBytes); | ||||
|          | ||||
|  | ||||
|         // Sign the payload with RSA-SHA256 | ||||
|         var signature = rsa.SignData(payloadBytes, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); | ||||
|          | ||||
|  | ||||
|         // Base64Url encode the signature | ||||
|         var signatureBase64 = Base64UrlEncode(signature); | ||||
|          | ||||
|  | ||||
|         // Combine payload and signature with a dot | ||||
|         return $"{payloadBase64}.{signatureBase64}"; | ||||
|     } | ||||
|      | ||||
|  | ||||
|     public bool ValidateToken(string token, out Guid sessionId) | ||||
|     { | ||||
|         sessionId = Guid.Empty; | ||||
|          | ||||
|  | ||||
|         try | ||||
|         { | ||||
|             // Split the token | ||||
|             var parts = token.Split('.'); | ||||
|             if (parts.Length != 2) | ||||
|                 return false; | ||||
|              | ||||
|  | ||||
|             // Decode the payload | ||||
|             var payloadBytes = Base64UrlDecode(parts[0]); | ||||
|              | ||||
|  | ||||
|             // Extract session ID | ||||
|             sessionId = new Guid(payloadBytes); | ||||
|              | ||||
|  | ||||
|             // Load public key for verification | ||||
|             var publicKeyPem = File.ReadAllText(config["Jwt:PublicKeyPath"]!); | ||||
|             using var rsa = RSA.Create(); | ||||
|             rsa.ImportFromPem(publicKeyPem); | ||||
|              | ||||
|  | ||||
|             // Verify signature | ||||
|             var signature = Base64UrlDecode(parts[1]); | ||||
|             return rsa.VerifyData(payloadBytes, signature, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); | ||||
| @@ -118,7 +178,7 @@ public class AuthService(IConfiguration config, IHttpClientFactory httpClientFac | ||||
|             return false; | ||||
|         } | ||||
|     } | ||||
|      | ||||
|  | ||||
|     // Helper methods for Base64Url encoding/decoding | ||||
|     private static string Base64UrlEncode(byte[] data) | ||||
|     { | ||||
| @@ -127,19 +187,19 @@ public class AuthService(IConfiguration config, IHttpClientFactory httpClientFac | ||||
|             .Replace('+', '-') | ||||
|             .Replace('/', '_'); | ||||
|     } | ||||
|      | ||||
|  | ||||
|     private static byte[] Base64UrlDecode(string base64Url) | ||||
|     { | ||||
|         string padded = base64Url | ||||
|             .Replace('-', '+') | ||||
|             .Replace('_', '/'); | ||||
|              | ||||
|  | ||||
|         switch (padded.Length % 4) | ||||
|         { | ||||
|             case 2: padded += "=="; break; | ||||
|             case 3: padded += "="; break; | ||||
|         } | ||||
|          | ||||
|  | ||||
|         return Convert.FromBase64String(padded); | ||||
|     } | ||||
| } | ||||
		Reference in New Issue
	
	Block a user