♻️ Refactored oidc

This commit is contained in:
LittleSheep 2025-06-29 11:53:44 +08:00
parent d4fa08d320
commit f8295c6a18
4 changed files with 175 additions and 252 deletions

View File

@ -1,8 +1,5 @@
using System.ComponentModel.DataAnnotations;
using System.Security.Claims;
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Text; using System.Text;
using DysonNetwork.Sphere.Developer;
using DysonNetwork.Sphere.Auth.OidcProvider.Options; using DysonNetwork.Sphere.Auth.OidcProvider.Options;
using DysonNetwork.Sphere.Auth.OidcProvider.Responses; using DysonNetwork.Sphere.Auth.OidcProvider.Responses;
using DysonNetwork.Sphere.Auth.OidcProvider.Services; using DysonNetwork.Sphere.Auth.OidcProvider.Services;
@ -10,210 +7,104 @@ using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using DysonNetwork.Sphere.Account;
using Microsoft.EntityFrameworkCore;
namespace DysonNetwork.Sphere.Auth.OidcProvider.Controllers; namespace DysonNetwork.Sphere.Auth.OidcProvider.Controllers;
[Route("connect")] [Route("/auth/open")]
[ApiController] [ApiController]
public class OidcProviderController( public class OidcProviderController(
AppDatabase db,
OidcProviderService oidcService, OidcProviderService oidcService,
IConfiguration configuration,
IOptions<OidcProviderOptions> options, IOptions<OidcProviderOptions> options,
ILogger<OidcProviderController> logger ILogger<OidcProviderController> logger
) )
: ControllerBase : ControllerBase
{ {
[HttpGet("authorize")]
public async Task<IActionResult> Authorize(
[Required][FromQuery(Name = "client_id")] Guid clientId,
[Required][FromQuery(Name = "response_type")] string responseType,
[FromQuery(Name = "redirect_uri")] string? redirectUri,
[FromQuery] string? scope,
[FromQuery] string? state,
[FromQuery] string? nonce,
[FromQuery(Name = "code_challenge")] string? codeChallenge,
[FromQuery(Name = "code_challenge_method")] string? codeChallengeMethod,
[FromQuery(Name = "response_mode")] string? responseMode
)
{
// Check if user is authenticated
if (HttpContext.Items["CurrentUser"] is not Account.Account currentUser)
{
// Not authenticated - redirect to login with return URL
var loginUrl = "/Auth/Login";
var returnUrl = $"{Request.Path}{Request.QueryString}";
return Redirect($"{loginUrl}?returnUrl={Uri.EscapeDataString(returnUrl)}");
}
// Validate client
var client = await oidcService.FindClientByIdAsync(clientId);
if (client == null)
return BadRequest(new ErrorResponse { Error = "invalid_client", ErrorDescription = "Client not found" });
// Check if user has already granted permission to this client
// For now, we'll always show the consent page. In a real app, you might store consent decisions.
// If you want to implement "remember my decision", you would check that here.
var consentRequired = true;
if (consentRequired)
{
// Redirect to consent page with all the OAuth parameters
var consentUrl = $"/Auth/Authorize?client_id={clientId}";
if (!string.IsNullOrEmpty(responseType)) consentUrl += $"&response_type={Uri.EscapeDataString(responseType)}";
if (!string.IsNullOrEmpty(redirectUri)) consentUrl += $"&redirect_uri={Uri.EscapeDataString(redirectUri)}";
if (!string.IsNullOrEmpty(scope)) consentUrl += $"&scope={Uri.EscapeDataString(scope)}";
if (!string.IsNullOrEmpty(state)) consentUrl += $"&state={Uri.EscapeDataString(state)}";
if (!string.IsNullOrEmpty(nonce)) consentUrl += $"&nonce={Uri.EscapeDataString(nonce)}";
if (!string.IsNullOrEmpty(codeChallenge)) consentUrl += $"&code_challenge={Uri.EscapeDataString(codeChallenge)}";
if (!string.IsNullOrEmpty(codeChallengeMethod)) consentUrl += $"&code_challenge_method={Uri.EscapeDataString(codeChallengeMethod)}";
if (!string.IsNullOrEmpty(responseMode)) consentUrl += $"&response_mode={Uri.EscapeDataString(responseMode)}";
return Redirect(consentUrl);
}
// Skip redirect_uri validation for apps in Developing status
if (client.Status != CustomAppStatus.Developing)
{
// Validate redirect URI for non-Developing apps
if (!string.IsNullOrEmpty(redirectUri) && !(client.RedirectUris?.Contains(redirectUri) ?? false))
return BadRequest(
new ErrorResponse { Error = "invalid_request", ErrorDescription = "Invalid redirect_uri" });
}
else
{
logger.LogWarning("Skipping redirect_uri validation for app {AppId} in Developing status", clientId);
// If no redirect_uri is provided and we're in development, use the first one
if (string.IsNullOrEmpty(redirectUri) && client.RedirectUris?.Any() == true)
{
redirectUri = client.RedirectUris.First();
}
}
// Generate authorization code
var code = Guid.NewGuid().ToString("N");
var userId = User.FindFirstValue(ClaimTypes.NameIdentifier);
// In a real implementation, you'd store this code with the user's consent and requested scopes
// and validate it in the token endpoint
// For now, we'll just return the code directly (simplified for example)
var response = new AuthorizationResponse
{
Code = code,
State = state,
Scope = scope,
Issuer = options.Value.IssuerUri
};
// Redirect back to the client with the authorization code
var finalRedirectUri = new UriBuilder(redirectUri ?? client.RedirectUris?.First() ?? throw new InvalidOperationException("No redirect URI provided and no default redirect URI found"));
var query = System.Web.HttpUtility.ParseQueryString(finalRedirectUri.Query);
query["code"] = response.Code;
if (!string.IsNullOrEmpty(response.State))
query["state"] = response.State;
if (!string.IsNullOrEmpty(response.Scope))
query["scope"] = response.Scope;
finalRedirectUri.Query = query.ToString();
return Redirect(finalRedirectUri.Uri.ToString());
}
[HttpPost("token")] [HttpPost("token")]
[Consumes("application/x-www-form-urlencoded")] [Consumes("application/x-www-form-urlencoded")]
public async Task<IActionResult> Token([FromForm] TokenRequest request) public async Task<IActionResult> Token([FromForm] TokenRequest request)
{ {
if (request.GrantType == "authorization_code") switch (request.GrantType)
{ {
// Validate client credentials // Validate client credentials
if (request.ClientId == null || string.IsNullOrEmpty(request.ClientSecret)) case "authorization_code" when request.ClientId == null || string.IsNullOrEmpty(request.ClientSecret):
return BadRequest(new ErrorResponse { Error = "invalid_client", ErrorDescription = "Client credentials are required" }); return BadRequest("Client credentials are required");
case "authorization_code" when request.Code == null:
return BadRequest("Authorization code is required");
case "authorization_code":
{
var client = await oidcService.FindClientByIdAsync(request.ClientId.Value); var client = await oidcService.FindClientByIdAsync(request.ClientId.Value);
if (client == null || !await oidcService.ValidateClientCredentialsAsync(request.ClientId.Value, request.ClientSecret)) if (client == null ||
return BadRequest(new ErrorResponse { Error = "invalid_client", ErrorDescription = "Invalid client credentials" }); !await oidcService.ValidateClientCredentialsAsync(request.ClientId.Value, request.ClientSecret))
return BadRequest(new ErrorResponse
{ Error = "invalid_client", ErrorDescription = "Invalid client credentials" });
// Validate the authorization code // Validate the authorization code
var authCode = await oidcService.ValidateAuthorizationCodeAsync( var authCode = await oidcService.ValidateAuthorizationCodeAsync(
request.Code ?? string.Empty, request.Code ?? string.Empty,
request.ClientId.Value, request.ClientId.Value,
request.RedirectUri, request.RedirectUri,
request.CodeVerifier); request.CodeVerifier
);
if (authCode == null) if (authCode == null)
{ {
logger.LogWarning("Invalid or expired authorization code: {Code}", request.Code); logger.LogWarning(@"Invalid or expired authorization code: {Code}", request.Code);
return BadRequest(new ErrorResponse { Error = "invalid_grant", ErrorDescription = "Invalid or expired authorization code" }); return BadRequest(new ErrorResponse
{ Error = "invalid_grant", ErrorDescription = "Invalid or expired authorization code" });
} }
// Generate tokens // Generate tokens
var tokenResponse = await oidcService.GenerateTokenResponseAsync( var tokenResponse = await oidcService.GenerateTokenResponseAsync(
clientId: request.ClientId.Value, clientId: request.ClientId.Value,
subjectId: authCode.UserId,
scopes: authCode.Scopes, scopes: authCode.Scopes,
authorizationCode: request.Code); authorizationCode: request.Code!
);
return Ok(tokenResponse); return Ok(tokenResponse);
} }
else if (request.GrantType == "refresh_token") case "refresh_token":
{
// Handle refresh token request // Handle refresh token request
// In a real implementation, you would validate the refresh token // In a real implementation, you would validate the refresh token
// and issue a new access token // and issue a new access token
return BadRequest(new ErrorResponse { Error = "unsupported_grant_type" }); return BadRequest(new ErrorResponse { Error = "unsupported_grant_type" });
} default:
return BadRequest(new ErrorResponse { Error = "unsupported_grant_type" }); return BadRequest(new ErrorResponse { Error = "unsupported_grant_type" });
} }
}
[HttpGet("userinfo")] [HttpGet("userinfo")]
[Authorize(AuthenticationSchemes = "Bearer")] [Authorize]
public async Task<IActionResult> UserInfo() public async Task<IActionResult> UserInfo()
{ {
var authHeader = HttpContext.Request.Headers.Authorization.ToString(); if (HttpContext.Items["CurrentUser"] is not Account.Account currentUser ||
if (string.IsNullOrEmpty(authHeader) || !authHeader.StartsWith("Bearer ")) HttpContext.Items["CurrentSession"] is not Session currentSession) return Unauthorized();
{
var loginUrl = "/Account/Login"; // Update this path to your actual login page path
var returnUrl = $"{Request.Scheme}://{Request.Host}{Request.Path}{Request.QueryString}";
return Redirect($"{loginUrl}?returnUrl={Uri.EscapeDataString(returnUrl)}");
}
var token = authHeader["Bearer ".Length..].Trim();
var jwtToken = oidcService.ValidateToken(token);
if (jwtToken == null)
{
var loginUrl = "/Account/Login"; // Update this path to your actual login page path
var returnUrl = $"{Request.Scheme}://{Request.Host}{Request.Path}{Request.QueryString}";
return Redirect($"{loginUrl}?returnUrl={Uri.EscapeDataString(returnUrl)}");
}
// Get user info based on the subject claim from the token
var userId = User.FindFirstValue(ClaimTypes.NameIdentifier);
var userName = User.FindFirstValue(ClaimTypes.Name);
var userEmail = User.FindFirstValue(ClaimTypes.Email);
// Get requested scopes from the token // Get requested scopes from the token
var scopes = jwtToken.Claims var scopes = currentSession.Challenge.Scopes;
.Where(c => c.Type == "scope")
.SelectMany(c => c.Value.Split(' '))
.ToHashSet();
var userInfo = new Dictionary<string, object> var userInfo = new Dictionary<string, object>
{ {
["sub"] = userId ?? "anonymous" ["sub"] = currentUser.Id
}; };
// Include standard claims based on scopes // Include standard claims based on scopes
if (scopes.Contains("profile") || scopes.Contains("name")) if (scopes.Contains("profile") || scopes.Contains("name"))
{ {
if (!string.IsNullOrEmpty(userName)) userInfo["name"] = currentUser.Name;
userInfo["name"] = userName; userInfo["preferred_username"] = currentUser.Nick;
} }
if (scopes.Contains("email") && !string.IsNullOrEmpty(userEmail)) var userEmail = await db.AccountContacts
.Where(c => c.Type == AccountContactType.Email && c.AccountId == currentUser.Id)
.FirstOrDefaultAsync();
if (scopes.Contains("email") && userEmail is not null)
{ {
userInfo["email"] = userEmail; userInfo["email"] = userEmail.Content;
userInfo["email_verified"] = true; // In a real app, check if email is verified userInfo["email_verified"] = userEmail.VerifiedAt is not null;
} }
return Ok(userInfo); return Ok(userInfo);
@ -222,18 +113,19 @@ public class OidcProviderController(
[HttpGet(".well-known/openid-configuration")] [HttpGet(".well-known/openid-configuration")]
public IActionResult GetConfiguration() public IActionResult GetConfiguration()
{ {
var baseUrl = $"{Request.Scheme}://{Request.Host}{Request.PathBase}".TrimEnd('/'); var baseUrl = configuration["BaseUrl"];
var issuer = options.Value.IssuerUri.TrimEnd('/'); var issuer = options.Value.IssuerUri.TrimEnd('/');
return Ok(new return Ok(new
{ {
issuer = issuer, issuer = issuer,
authorization_endpoint = $"{baseUrl}/connect/authorize", authorization_endpoint = $"{baseUrl}/connect/authorize",
token_endpoint = $"{baseUrl}/connect/token", token_endpoint = $"{baseUrl}/auth/open/token",
userinfo_endpoint = $"{baseUrl}/connect/userinfo", userinfo_endpoint = $"{baseUrl}/auth/open/userinfo",
jwks_uri = $"{baseUrl}/.well-known/openid-configuration/jwks", jwks_uri = $"{baseUrl}/.well-known/openid-configuration/jwks",
scopes_supported = new[] { "openid", "profile", "email" }, scopes_supported = new[] { "openid", "profile", "email" },
response_types_supported = new[] { "code", "token", "id_token", "code token", "code id_token", "token id_token", "code token id_token" }, response_types_supported = new[]
{ "code", "token", "id_token", "code token", "code id_token", "token id_token", "code token id_token" },
grant_types_supported = new[] { "authorization_code", "refresh_token" }, grant_types_supported = new[] { "authorization_code", "refresh_token" },
token_endpoint_auth_methods_supported = new[] { "client_secret_basic", "client_secret_post" }, token_endpoint_auth_methods_supported = new[] { "client_secret_basic", "client_secret_post" },
id_token_signing_alg_values_supported = new[] { "HS256" }, id_token_signing_alg_values_supported = new[] { "HS256" },
@ -248,12 +140,8 @@ public class OidcProviderController(
} }
[HttpGet("jwks")] [HttpGet("jwks")]
public IActionResult Jwks() public IActionResult GetJwks()
{ {
// In a production environment, you should use asymmetric keys (RSA or EC)
// and expose only the public key here. This is a simplified example using HMAC.
// For production, consider using RSA or EC keys and proper key rotation.
var keyBytes = Encoding.UTF8.GetBytes(options.Value.SigningKey); var keyBytes = Encoding.UTF8.GetBytes(options.Value.SigningKey);
var keyId = Convert.ToBase64String(SHA256.HashData(keyBytes)[..8]) var keyId = Convert.ToBase64String(SHA256.HashData(keyBytes)[..8])
.Replace("+", "-") .Replace("+", "-")
@ -279,27 +167,19 @@ public class OidcProviderController(
public class TokenRequest public class TokenRequest
{ {
[JsonPropertyName("grant_type")] [JsonPropertyName("grant_type")] public string? GrantType { get; set; }
public string? GrantType { get; set; }
[JsonPropertyName("code")] [JsonPropertyName("code")] public string? Code { get; set; }
public string? Code { get; set; }
[JsonPropertyName("redirect_uri")] [JsonPropertyName("redirect_uri")] public string? RedirectUri { get; set; }
public string? RedirectUri { get; set; }
[JsonPropertyName("client_id")] [JsonPropertyName("client_id")] public Guid? ClientId { get; set; }
public Guid? ClientId { get; set; }
[JsonPropertyName("client_secret")] [JsonPropertyName("client_secret")] public string? ClientSecret { get; set; }
public string? ClientSecret { get; set; }
[JsonPropertyName("refresh_token")] [JsonPropertyName("refresh_token")] public string? RefreshToken { get; set; }
public string? RefreshToken { get; set; }
[JsonPropertyName("scope")] [JsonPropertyName("scope")] public string? Scope { get; set; }
public string? Scope { get; set; }
[JsonPropertyName("code_verifier")] [JsonPropertyName("code_verifier")] public string? CodeVerifier { get; set; }
public string? CodeVerifier { get; set; }
} }

View File

@ -7,7 +7,7 @@ namespace DysonNetwork.Sphere.Auth.OidcProvider.Models;
public class AuthorizationCodeInfo public class AuthorizationCodeInfo
{ {
public Guid ClientId { get; set; } public Guid ClientId { get; set; }
public string UserId { get; set; } = string.Empty; public Guid AccountId { get; set; }
public string RedirectUri { get; set; } = string.Empty; public string RedirectUri { get; set; } = string.Empty;
public List<string> Scopes { get; set; } = new(); public List<string> Scopes { get; set; } = new();
public string? CodeChallenge { get; set; } public string? CodeChallenge { get; set; }

View File

@ -2,13 +2,11 @@ using System.IdentityModel.Tokens.Jwt;
using System.Security.Claims; using System.Security.Claims;
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Text; using System.Text;
using System.Text.Encodings.Web;
using System.Text.Json;
using DysonNetwork.Sphere.Auth.OidcProvider.Models; using DysonNetwork.Sphere.Auth.OidcProvider.Models;
using DysonNetwork.Sphere.Auth.OidcProvider.Options; using DysonNetwork.Sphere.Auth.OidcProvider.Options;
using DysonNetwork.Sphere.Auth.OidcProvider.Responses; using DysonNetwork.Sphere.Auth.OidcProvider.Responses;
using DysonNetwork.Sphere.Developer; using DysonNetwork.Sphere.Developer;
using Microsoft.AspNetCore.Identity; using DysonNetwork.Sphere.Storage;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Microsoft.IdentityModel.Tokens; using Microsoft.IdentityModel.Tokens;
@ -18,7 +16,8 @@ namespace DysonNetwork.Sphere.Auth.OidcProvider.Services;
public class OidcProviderService( public class OidcProviderService(
AppDatabase db, AppDatabase db,
IClock clock, AuthService auth,
ICacheService cache,
IOptions<OidcProviderOptions> options, IOptions<OidcProviderOptions> options,
ILogger<OidcProviderService> logger ILogger<OidcProviderService> logger
) )
@ -44,6 +43,7 @@ public class OidcProviderService(
var client = await FindClientByIdAsync(clientId); var client = await FindClientByIdAsync(clientId);
if (client == null) return false; if (client == null) return false;
var clock = SystemClock.Instance;
var secret = client.Secrets var secret = client.Secrets
.Where(s => s.IsOidc && (s.ExpiredAt == null || s.ExpiredAt > clock.GetCurrentInstant())) .Where(s => s.IsOidc && (s.ExpiredAt == null || s.ExpiredAt > clock.GetCurrentInstant()))
.FirstOrDefault(s => s.Secret == clientSecret); // In production, use proper hashing .FirstOrDefault(s => s.Secret == clientSecret); // In production, use proper hashing
@ -53,20 +53,28 @@ public class OidcProviderService(
public async Task<TokenResponse> GenerateTokenResponseAsync( public async Task<TokenResponse> GenerateTokenResponseAsync(
Guid clientId, Guid clientId,
string subjectId, string authorizationCode,
IEnumerable<string>? scopes = null, IEnumerable<string>? scopes = null
string? authorizationCode = null) )
{ {
var client = await FindClientByIdAsync(clientId); var client = await FindClientByIdAsync(clientId);
if (client == null) if (client == null)
throw new InvalidOperationException("Client not found"); throw new InvalidOperationException("Client not found");
var authCode = await ValidateAuthorizationCodeAsync(authorizationCode, clientId);
if (authCode is null) throw new InvalidOperationException("Invalid authorization code");
var account = await db.Accounts.Where(a => a.Id == authCode.AccountId).FirstOrDefaultAsync();
if (account is null) throw new InvalidOperationException("Account was not found");
var clock = SystemClock.Instance;
var now = clock.GetCurrentInstant(); var now = clock.GetCurrentInstant();
var session = await auth.CreateSessionAsync(account, now);
var expiresIn = (int)_options.AccessTokenLifetime.TotalSeconds; var expiresIn = (int)_options.AccessTokenLifetime.TotalSeconds;
var expiresAt = now.Plus(Duration.FromSeconds(expiresIn)); var expiresAt = now.Plus(Duration.FromSeconds(expiresIn));
// Generate access token // Generate access token
var accessToken = GenerateJwtToken(client, subjectId, expiresAt, scopes); var accessToken = GenerateJwtToken(client, session, expiresAt, scopes);
var refreshToken = GenerateRefreshToken(); var refreshToken = GenerateRefreshToken();
// In a real implementation, you would store the token in the database // In a real implementation, you would store the token in the database
@ -83,27 +91,33 @@ public class OidcProviderService(
}; };
} }
private string GenerateJwtToken(CustomApp client, string subjectId, Instant expiresAt, IEnumerable<string>? scopes = null) private string GenerateJwtToken(
CustomApp client,
Session session,
Instant expiresAt,
IEnumerable<string>? scopes = null
)
{ {
var tokenHandler = new JwtSecurityTokenHandler(); var tokenHandler = new JwtSecurityTokenHandler();
var key = Encoding.ASCII.GetBytes(_options.SigningKey); var key = Encoding.ASCII.GetBytes(_options.SigningKey);
var clock = SystemClock.Instance;
var tokenDescriptor = new SecurityTokenDescriptor var tokenDescriptor = new SecurityTokenDescriptor
{ {
Subject = new ClaimsIdentity(new[] Subject = new ClaimsIdentity([
{ new Claim(JwtRegisteredClaimNames.Sub, session.Id.ToString()),
new Claim(JwtRegisteredClaimNames.Sub, subjectId),
new Claim(JwtRegisteredClaimNames.Jti, Guid.NewGuid().ToString()), new Claim(JwtRegisteredClaimNames.Jti, Guid.NewGuid().ToString()),
new Claim(JwtRegisteredClaimNames.Iat, clock.GetCurrentInstant().ToUnixTimeSeconds().ToString(), new Claim(JwtRegisteredClaimNames.Iat, clock.GetCurrentInstant().ToUnixTimeSeconds().ToString(),
ClaimValueTypes.Integer64), ClaimValueTypes.Integer64),
new Claim("client_id", client.Id.ToString()) new Claim("client_id", client.Id.ToString())
}), ]),
Expires = expiresAt.ToDateTimeUtc(), Expires = expiresAt.ToDateTimeUtc(),
Issuer = _options.IssuerUri, Issuer = _options.IssuerUri,
Audience = client.Id.ToString(), Audience = client.Id.ToString(),
SigningCredentials = new SigningCredentials( SigningCredentials = new SigningCredentials(
new SymmetricSecurityKey(key), new SymmetricSecurityKey(key),
SecurityAlgorithms.HmacSha256Signature) SecurityAlgorithms.HmacSha256Signature
)
}; };
// Add scopes as claims if provided, otherwise use client's default scopes // Add scopes as claims if provided, otherwise use client's default scopes
@ -163,11 +177,11 @@ public class OidcProviderService(
} }
} }
private static readonly Dictionary<string, AuthorizationCodeInfo> _authorizationCodes = new(); // Authorization codes are now managed through ICacheService
public async Task<string> GenerateAuthorizationCodeAsync( public async Task<string> GenerateAuthorizationCodeAsync(
Guid clientId, Guid clientId,
string userId, Guid userId,
string redirectUri, string redirectUri,
IEnumerable<string> scopes, IEnumerable<string> scopes,
string? codeChallenge = null, string? codeChallenge = null,
@ -175,14 +189,15 @@ public class OidcProviderService(
string? nonce = null) string? nonce = null)
{ {
// Generate a random code // Generate a random code
var clock = SystemClock.Instance;
var code = GenerateRandomString(32); var code = GenerateRandomString(32);
var now = clock.GetCurrentInstant(); var now = clock.GetCurrentInstant();
// Store the code with its metadata // Create the authorization code info
_authorizationCodes[code] = new AuthorizationCodeInfo var authCodeInfo = new AuthorizationCodeInfo
{ {
ClientId = clientId, ClientId = clientId,
UserId = userId, AccountId = userId,
RedirectUri = redirectUri, RedirectUri = redirectUri,
Scopes = scopes.ToList(), Scopes = scopes.ToList(),
CodeChallenge = codeChallenge, CodeChallenge = codeChallenge,
@ -192,6 +207,10 @@ public class OidcProviderService(
CreatedAt = now CreatedAt = now
}; };
// Store the code with its metadata in the cache
var cacheKey = $"auth:code:{code}";
await cache.SetAsync(cacheKey, authCodeInfo, _options.AuthorizationCodeLifetime);
logger.LogInformation("Generated authorization code for client {ClientId} and user {UserId}", clientId, userId); logger.LogInformation("Generated authorization code for client {ClientId} and user {UserId}", clientId, userId);
return code; return code;
} }
@ -200,28 +219,34 @@ public class OidcProviderService(
string code, string code,
Guid clientId, Guid clientId,
string? redirectUri = null, string? redirectUri = null,
string? codeVerifier = null) string? codeVerifier = null
)
{ {
if (!_authorizationCodes.TryGetValue(code, out var authCode) || authCode == null) var cacheKey = $"auth:code:{code}";
var (found, authCode) = await cache.GetAsyncWithStatus<AuthorizationCodeInfo>(cacheKey);
if (!found || authCode == null)
{ {
logger.LogWarning("Authorization code not found: {Code}", code); logger.LogWarning("Authorization code not found: {Code}", code);
return null; return null;
} }
var clock = SystemClock.Instance;
var now = clock.GetCurrentInstant(); var now = clock.GetCurrentInstant();
// Check if code has expired // Check if the code has expired
if (now > authCode.Expiration) if (now > authCode.Expiration)
{ {
logger.LogWarning("Authorization code expired: {Code}", code); logger.LogWarning("Authorization code expired: {Code}", code);
_authorizationCodes.Remove(code); await cache.RemoveAsync(cacheKey);
return null; return null;
} }
// Verify client ID matches // Verify client ID matches
if (authCode.ClientId != clientId) if (authCode.ClientId != clientId)
{ {
logger.LogWarning("Client ID mismatch for code {Code}. Expected: {ExpectedClientId}, Actual: {ActualClientId}", logger.LogWarning(
"Client ID mismatch for code {Code}. Expected: {ExpectedClientId}, Actual: {ActualClientId}",
code, authCode.ClientId, clientId); code, authCode.ClientId, clientId);
return null; return null;
} }
@ -256,8 +281,8 @@ public class OidcProviderService(
} }
} }
// Code is valid, remove it from the store (codes are single-use) // Code is valid, remove it from the cache (codes are single-use)
_authorizationCodes.Remove(code); await cache.RemoveAsync(cacheKey);
return authCode; return authCode;
} }

View File

@ -4,6 +4,8 @@ using Microsoft.AspNetCore.Mvc.RazorPages;
using DysonNetwork.Sphere.Auth.OidcProvider.Services; using DysonNetwork.Sphere.Auth.OidcProvider.Services;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using DysonNetwork.Sphere.Auth.OidcProvider.Responses;
using DysonNetwork.Sphere.Developer;
namespace DysonNetwork.Sphere.Pages.Auth; namespace DysonNetwork.Sphere.Pages.Auth;
@ -51,6 +53,12 @@ public class AuthorizeModel(OidcProviderService oidcService) : PageModel
public async Task<IActionResult> OnGetAsync() public async Task<IActionResult> OnGetAsync()
{ {
if (HttpContext.Items["CurrentUser"] is not Sphere.Account.Account)
{
var returnUrl = Uri.EscapeDataString($"{Request.Path}{Request.QueryString}");
return RedirectToPage($"/Auth/Login?returnUrl={returnUrl}");
}
if (string.IsNullOrEmpty(ClientIdString) || !Guid.TryParse(ClientIdString, out var clientId)) if (string.IsNullOrEmpty(ClientIdString) || !Guid.TryParse(ClientIdString, out var clientId))
{ {
ModelState.AddModelError("client_id", "Invalid client_id format"); ModelState.AddModelError("client_id", "Invalid client_id format");
@ -66,6 +74,14 @@ public class AuthorizeModel(OidcProviderService oidcService) : PageModel
return NotFound("Client not found"); return NotFound("Client not found");
} }
if (client.Status != CustomAppStatus.Developing)
{
// Validate redirect URI for non-Developing apps
if (!string.IsNullOrEmpty(RedirectUri) && !(client.RedirectUris?.Contains(RedirectUri) ?? false))
return BadRequest(
new ErrorResponse { Error = "invalid_request", ErrorDescription = "Invalid redirect_uri" });
}
AppName = client.Name; AppName = client.Name;
AppLogo = client.LogoUri; AppLogo = client.LogoUri;
AppUri = client.ClientUri; AppUri = client.ClientUri;
@ -76,7 +92,9 @@ public class AuthorizeModel(OidcProviderService oidcService) : PageModel
public async Task<IActionResult> OnPostAsync(bool allow) public async Task<IActionResult> OnPostAsync(bool allow)
{ {
// First validate the client ID if (HttpContext.Items["CurrentUser"] is not Sphere.Account.Account currentUser) return Unauthorized();
// First, validate the client ID
if (string.IsNullOrEmpty(ClientIdString) || !Guid.TryParse(ClientIdString, out var clientId)) if (string.IsNullOrEmpty(ClientIdString) || !Guid.TryParse(ClientIdString, out var clientId))
{ {
ModelState.AddModelError("client_id", "Invalid client_id format"); ModelState.AddModelError("client_id", "Invalid client_id format");
@ -85,7 +103,7 @@ public class AuthorizeModel(OidcProviderService oidcService) : PageModel
ClientId = clientId; ClientId = clientId;
// Check if client exists // Check if a client exists
var client = await oidcService.FindClientByIdAsync(ClientId); var client = await oidcService.FindClientByIdAsync(ClientId);
if (client == null) if (client == null)
{ {
@ -119,7 +137,7 @@ public class AuthorizeModel(OidcProviderService oidcService) : PageModel
// Generate authorization code // Generate authorization code
var authCode = await oidcService.GenerateAuthorizationCodeAsync( var authCode = await oidcService.GenerateAuthorizationCodeAsync(
clientId: ClientId, clientId: ClientId,
userId: User.Identity?.Name ?? string.Empty, userId: currentUser.Id,
redirectUri: RedirectUri, redirectUri: RedirectUri,
scopes: Scope?.Split(' ', StringSplitOptions.RemoveEmptyEntries) ?? Array.Empty<string>(), scopes: Scope?.Split(' ', StringSplitOptions.RemoveEmptyEntries) ?? Array.Empty<string>(),
codeChallenge: CodeChallenge, codeChallenge: CodeChallenge,