Oidc auto approval and session reuse

This commit is contained in:
LittleSheep 2025-06-29 17:46:17 +08:00
parent 0226bf8fa3
commit c0879d30d4
2 changed files with 138 additions and 19 deletions

View File

@ -38,6 +38,20 @@ public class OidcProviderService(
.FirstOrDefaultAsync(c => c.Id == appId); .FirstOrDefaultAsync(c => c.Id == appId);
} }
public async Task<Session?> FindValidSessionAsync(Guid accountId, Guid clientId)
{
var now = SystemClock.Instance.GetCurrentInstant();
return await db.AuthSessions
.Include(s => s.Challenge)
.Where(s => s.AccountId == accountId &&
s.AppId == clientId &&
(s.ExpiredAt == null || s.ExpiredAt > now) &&
s.Challenge.Type == ChallengeType.OAuth)
.OrderByDescending(s => s.CreatedAt)
.FirstOrDefaultAsync();
}
public async Task<bool> ValidateClientCredentialsAsync(Guid clientId, string clientSecret) public async Task<bool> ValidateClientCredentialsAsync(Guid clientId, string clientSecret)
{ {
var client = await FindClientByIdAsync(clientId); var client = await FindClientByIdAsync(clientId);
@ -76,7 +90,7 @@ public class OidcProviderService(
var account = await db.Accounts.Where(a => a.Id == authCode.AccountId).FirstOrDefaultAsync(); var account = await db.Accounts.Where(a => a.Id == authCode.AccountId).FirstOrDefaultAsync();
if (account is null) throw new InvalidOperationException("Account was not found"); if (account is null) throw new InvalidOperationException("Account was not found");
session = await auth.CreateSessionForOidcAsync(account, now); session = await auth.CreateSessionForOidcAsync(account, now, client.Id);
scopes = authCode.Scopes; scopes = authCode.Scopes;
} }
else if (sessionId.HasValue) else if (sessionId.HasValue)
@ -207,6 +221,44 @@ public class OidcProviderService(
return string.Equals(secret, hashedSecret, StringComparison.Ordinal); return string.Equals(secret, hashedSecret, StringComparison.Ordinal);
} }
public async Task<string> GenerateAuthorizationCodeForExistingSessionAsync(
Session session,
Guid clientId,
string redirectUri,
IEnumerable<string> scopes,
string? codeChallenge = null,
string? codeChallengeMethod = null,
string? nonce = null)
{
var clock = SystemClock.Instance;
var now = clock.GetCurrentInstant();
var code = Guid.NewGuid().ToString("N");
// Update the session's last activity time
await db.AuthSessions.Where(s => s.Id == session.Id)
.ExecuteUpdateAsync(s => s.SetProperty(s => s.LastGrantedAt, now));
// Create the authorization code info
var authCodeInfo = new AuthorizationCodeInfo
{
ClientId = clientId,
AccountId = session.AccountId,
RedirectUri = redirectUri,
Scopes = scopes.ToList(),
CodeChallenge = codeChallenge,
CodeChallengeMethod = codeChallengeMethod,
Nonce = nonce,
CreatedAt = now
};
// Store the code with its metadata in the cache
var cacheKey = $"auth:code:{code}";
await cache.SetAsync(cacheKey, authCodeInfo, _options.AuthorizationCodeLifetime);
logger.LogInformation("Generated authorization code for client {ClientId} and user {UserId}", clientId, session.AccountId);
return code;
}
public async Task<string> GenerateAuthorizationCodeAsync( public async Task<string> GenerateAuthorizationCodeAsync(
Guid clientId, Guid clientId,
Guid userId, Guid userId,
@ -214,7 +266,8 @@ public class OidcProviderService(
IEnumerable<string> scopes, IEnumerable<string> scopes,
string? codeChallenge = null, string? codeChallenge = null,
string? codeChallengeMethod = null, string? codeChallengeMethod = null,
string? nonce = null) string? nonce = null
)
{ {
// Generate a random code // Generate a random code
var clock = SystemClock.Instance; var clock = SystemClock.Instance;

View File

@ -1,9 +1,8 @@
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.RazorPages; using Microsoft.AspNetCore.Mvc.RazorPages;
using DysonNetwork.Sphere.Auth.OidcProvider.Services; using DysonNetwork.Sphere.Auth.OidcProvider.Services;
using Microsoft.EntityFrameworkCore;
using System.ComponentModel.DataAnnotations; using System.ComponentModel.DataAnnotations;
using DysonNetwork.Sphere.Auth;
using DysonNetwork.Sphere.Auth.OidcProvider.Responses; using DysonNetwork.Sphere.Auth.OidcProvider.Responses;
using DysonNetwork.Sphere.Developer; using DysonNetwork.Sphere.Developer;
@ -48,12 +47,14 @@ public class AuthorizeModel(OidcProviderService oidcService) : PageModel
public async Task<IActionResult> OnGetAsync() public async Task<IActionResult> OnGetAsync()
{ {
if (HttpContext.Items["CurrentUser"] is not Sphere.Account.Account) // First check if user is authenticated
if (HttpContext.Items["CurrentUser"] is not Sphere.Account.Account currentUser)
{ {
var returnUrl = Uri.EscapeDataString($"{Request.Path}{Request.QueryString}"); var returnUrl = Uri.EscapeDataString($"{Request.Path}{Request.QueryString}");
return RedirectToPage("/Auth/Login", new { returnUrl }); return RedirectToPage("/Auth/Login", new { returnUrl });
} }
// Validate client_id
if (string.IsNullOrEmpty(ClientIdString) || !Guid.TryParse(ClientIdString, out var clientId)) if (string.IsNullOrEmpty(ClientIdString) || !Guid.TryParse(ClientIdString, out var clientId))
{ {
ModelState.AddModelError("client_id", "Invalid client_id format"); ModelState.AddModelError("client_id", "Invalid client_id format");
@ -62,6 +63,7 @@ public class AuthorizeModel(OidcProviderService oidcService) : PageModel
ClientId = clientId; ClientId = clientId;
// Get client info
var client = await oidcService.FindClientByIdAsync(ClientId); var client = await oidcService.FindClientByIdAsync(ClientId);
if (client == null) if (client == null)
{ {
@ -69,14 +71,28 @@ public class AuthorizeModel(OidcProviderService oidcService) : PageModel
return NotFound("Client not found"); return NotFound("Client not found");
} }
// Validate redirect URI for non-Developing apps
if (client.Status != CustomAppStatus.Developing) if (client.Status != CustomAppStatus.Developing)
{ {
// Validate redirect URI for non-Developing apps
if (!string.IsNullOrEmpty(RedirectUri) && !(client.RedirectUris?.Contains(RedirectUri) ?? false)) if (!string.IsNullOrEmpty(RedirectUri) && !(client.RedirectUris?.Contains(RedirectUri) ?? false))
return BadRequest( {
new ErrorResponse { Error = "invalid_request", ErrorDescription = "Invalid redirect_uri" }); return BadRequest(new ErrorResponse
{
Error = "invalid_request",
ErrorDescription = "Invalid redirect_uri"
});
}
} }
// Check for an existing valid session
var existingSession = await oidcService.FindValidSessionAsync(currentUser.Id, clientId);
if (existingSession != null)
{
// Auto-approve since valid session exists
return await HandleApproval(currentUser, client, existingSession);
}
// Show authorization page
AppName = client.Name; AppName = client.Name;
AppLogo = client.LogoUri; AppLogo = client.LogoUri;
AppUri = client.ClientUri; AppUri = client.ClientUri;
@ -85,6 +101,56 @@ public class AuthorizeModel(OidcProviderService oidcService) : PageModel
return Page(); return Page();
} }
private async Task<IActionResult> HandleApproval(Sphere.Account.Account currentUser, CustomApp client, Session? existingSession = null)
{
if (string.IsNullOrEmpty(RedirectUri))
{
ModelState.AddModelError("redirect_uri", "No redirect_uri provided");
return BadRequest("No redirect_uri provided");
}
string authCode;
if (existingSession != null)
{
// Reuse existing session
authCode = await oidcService.GenerateAuthorizationCodeForExistingSessionAsync(
session: existingSession,
clientId: ClientId,
redirectUri: RedirectUri,
scopes: Scope?.Split(' ', StringSplitOptions.RemoveEmptyEntries) ?? [],
codeChallenge: CodeChallenge,
codeChallengeMethod: CodeChallengeMethod,
nonce: Nonce
);
}
else
{
// Create new session (existing flow)
authCode = await oidcService.GenerateAuthorizationCodeAsync(
clientId: ClientId,
userId: currentUser.Id,
redirectUri: RedirectUri,
scopes: Scope?.Split(' ', StringSplitOptions.RemoveEmptyEntries) ?? [],
codeChallenge: CodeChallenge,
codeChallengeMethod: CodeChallengeMethod,
nonce: Nonce
);
}
// Build the redirect URI with the authorization code
var redirectUriBuilder = new UriBuilder(RedirectUri);
var query = System.Web.HttpUtility.ParseQueryString(redirectUriBuilder.Query);
query["code"] = authCode;
if (!string.IsNullOrEmpty(State))
query["state"] = State;
if (!string.IsNullOrEmpty(Scope))
query["scope"] = Scope;
redirectUriBuilder.Query = query.ToString();
return Redirect(redirectUriBuilder.ToString());
}
public async Task<IActionResult> OnPostAsync(bool allow) public async Task<IActionResult> OnPostAsync(bool allow)
{ {
if (HttpContext.Items["CurrentUser"] is not Sphere.Account.Account currentUser) return Unauthorized(); if (HttpContext.Items["CurrentUser"] is not Sphere.Account.Account currentUser) return Unauthorized();