:drunk: No idea what did AI did
This commit is contained in:
@ -1,47 +1,83 @@
|
||||
using System;
|
||||
using System.IdentityModel.Tokens.Jwt;
|
||||
using System.Security.Claims;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Authentication;
|
||||
using Microsoft.AspNetCore.Authorization;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.Mvc;
|
||||
using Microsoft.EntityFrameworkCore;
|
||||
using Microsoft.IdentityModel.Tokens;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using NodaTime;
|
||||
using DysonNetwork.Common.Models;
|
||||
using DysonNetwork.Pass.Data;
|
||||
using DysonNetwork.Sphere;
|
||||
using DysonNetwork.Pass.Features.Auth.Models;
|
||||
using DysonNetwork.Pass.Features.Auth.Services;
|
||||
using Microsoft.IdentityModel.Tokens;
|
||||
|
||||
// Use fully qualified names to avoid ambiguity
|
||||
using CommonAccount = DysonNetwork.Common.Models.Account;
|
||||
using CommonOidcUserInfo = DysonNetwork.Common.Models.OidcUserInfo;
|
||||
|
||||
namespace DysonNetwork.Pass.Features.Auth.OpenId;
|
||||
|
||||
[ApiController]
|
||||
[Route("/auth/login")]
|
||||
public class OidcController(
|
||||
IServiceProvider serviceProvider,
|
||||
PassDatabase passDb,
|
||||
AppDatabase sphereDb,
|
||||
AccountService accounts,
|
||||
ICacheService cache
|
||||
)
|
||||
: ControllerBase
|
||||
public class OidcController : ControllerBase
|
||||
{
|
||||
private const string StateCachePrefix = "oidc-state:";
|
||||
private static readonly TimeSpan StateExpiration = TimeSpan.FromMinutes(15);
|
||||
private readonly ILogger<OidcController> _logger;
|
||||
private readonly IServiceProvider _serviceProvider;
|
||||
private readonly PassDatabase _db;
|
||||
private readonly IAccountService _accountService;
|
||||
private readonly IAccountConnectionService _connectionService;
|
||||
private readonly ICacheService _cache;
|
||||
|
||||
public OidcController(
|
||||
IServiceProvider serviceProvider,
|
||||
PassDatabase db,
|
||||
IAccountService accountService,
|
||||
IAccountConnectionService connectionService,
|
||||
ICacheService cache,
|
||||
ILogger<OidcController> logger)
|
||||
{
|
||||
_serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider));
|
||||
_db = db ?? throw new ArgumentNullException(nameof(db));
|
||||
_accountService = accountService ?? throw new ArgumentNullException(nameof(accountService));
|
||||
_connectionService = connectionService ?? throw new ArgumentNullException(nameof(connectionService));
|
||||
_cache = cache ?? throw new ArgumentNullException(nameof(cache));
|
||||
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
|
||||
}
|
||||
|
||||
[HttpGet("{provider}")]
|
||||
public async Task<ActionResult> OidcLogin(
|
||||
[FromRoute] string provider,
|
||||
[FromQuery] string? returnUrl = "/",
|
||||
[FromHeader(Name = "X-Device-Id")] string? deviceId = null
|
||||
)
|
||||
[FromHeader(Name = "X-Device-Id")] string? deviceId = null)
|
||||
{
|
||||
try
|
||||
{
|
||||
var oidcService = GetOidcService(provider);
|
||||
|
||||
// If the user is already authenticated, treat as an account connection request
|
||||
if (HttpContext.Items["CurrentUser"] is Account currentUser)
|
||||
var currentUser = await HttpContext.AuthenticateAsync();
|
||||
if (currentUser.Succeeded && currentUser.Principal?.Identity?.IsAuthenticated == true)
|
||||
{
|
||||
var state = Guid.NewGuid().ToString();
|
||||
var nonce = Guid.NewGuid().ToString();
|
||||
|
||||
// Get the current user's account ID
|
||||
var accountId = currentUser.Principal.FindFirstValue(ClaimTypes.NameIdentifier);
|
||||
if (string.IsNullOrEmpty(accountId))
|
||||
{
|
||||
_logger.LogWarning("Authenticated user does not have a valid account ID");
|
||||
return Unauthorized();
|
||||
}
|
||||
|
||||
// Create and store connection state
|
||||
var oidcState = OidcState.ForConnection(currentUser.Id, provider, nonce, deviceId);
|
||||
await cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration);
|
||||
var oidcState = OidcState.ForConnection(accountId, provider, nonce, deviceId);
|
||||
await _cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration);
|
||||
|
||||
// The state parameter sent to the provider is the GUID key for the cache.
|
||||
var authUrl = oidcService.GetAuthorizationUrl(state, nonce);
|
||||
@ -49,12 +85,15 @@ public class OidcController(
|
||||
}
|
||||
else // Otherwise, proceed with the login / registration flow
|
||||
{
|
||||
var nonce = Guid.NewGuid().ToString();
|
||||
var state = Guid.NewGuid().ToString();
|
||||
var nonce = Guid.NewGuid().ToString();
|
||||
|
||||
// Create login state with return URL and device ID
|
||||
// Store the state and nonce for validation later
|
||||
var oidcState = OidcState.ForLogin(returnUrl ?? "/", deviceId);
|
||||
await cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration);
|
||||
oidcState.Provider = provider;
|
||||
oidcState.Nonce = nonce;
|
||||
await _cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration);
|
||||
|
||||
var authUrl = oidcService.GetAuthorizationUrl(state, nonce);
|
||||
return Redirect(authUrl);
|
||||
}
|
||||
@ -70,7 +109,7 @@ public class OidcController(
|
||||
/// Handles Apple authentication directly from mobile apps
|
||||
/// </summary>
|
||||
[HttpPost("apple/mobile")]
|
||||
public async Task<ActionResult<AuthChallenge>> AppleMobileSignIn(
|
||||
public async Task<ActionResult<Models.AuthChallenge>> AppleMobileSignIn(
|
||||
[FromBody] AppleMobileSignInRequest request)
|
||||
{
|
||||
try
|
||||
@ -100,6 +139,11 @@ public class OidcController(
|
||||
request.DeviceId
|
||||
);
|
||||
|
||||
if (challenge == null)
|
||||
{
|
||||
return StatusCode(StatusCodes.Status500InternalServerError, "Failed to create authentication challenge");
|
||||
}
|
||||
|
||||
return Ok(challenge);
|
||||
}
|
||||
catch (SecurityTokenValidationException ex)
|
||||
@ -113,85 +157,141 @@ public class OidcController(
|
||||
}
|
||||
}
|
||||
|
||||
private async Task<IActionResult> HandleLogin(OidcState oidcState, OidcUserInfo userInfo)
|
||||
{
|
||||
try
|
||||
{
|
||||
// Find or create the account
|
||||
var account = await _accountService.FindOrCreateAccountAsync(userInfo, oidcState.Provider ?? throw new InvalidOperationException("Provider not specified"));
|
||||
if (account == null)
|
||||
{
|
||||
_logger.LogError("Failed to find or create account for user {UserId}", userInfo.UserId);
|
||||
return StatusCode(StatusCodes.Status500InternalServerError, "Failed to process your account");
|
||||
}
|
||||
|
||||
// Create a new session
|
||||
var session = await _connectionService.CreateSessionAsync(account, oidcState.DeviceId);
|
||||
if (session == null)
|
||||
{
|
||||
_logger.LogError("Failed to create session for account {AccountId}", account.Id);
|
||||
return StatusCode(StatusCodes.Status500InternalServerError, "Failed to create session");
|
||||
}
|
||||
|
||||
// Create auth tokens
|
||||
var tokens = await _accountService.GenerateAuthTokensAsync(account, session.Id.ToString());
|
||||
|
||||
// Return the tokens and redirect URL
|
||||
return Ok(new
|
||||
{
|
||||
tokens.AccessToken,
|
||||
tokens.RefreshToken,
|
||||
tokens.ExpiresIn,
|
||||
ReturnUrl = oidcState.ReturnUrl ?? "/"
|
||||
});
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error handling OIDC login for user {UserId}", userInfo.UserId);
|
||||
return StatusCode(StatusCodes.Status500InternalServerError, "An error occurred during login");
|
||||
}
|
||||
}
|
||||
|
||||
private async Task<IActionResult> HandleAccountConnection(OidcState oidcState, OidcUserInfo userInfo)
|
||||
{
|
||||
try
|
||||
{
|
||||
// Get the current user's account
|
||||
if (!Guid.TryParse(oidcState.AccountId, out var accountId))
|
||||
{
|
||||
_logger.LogError("Invalid account ID format: {AccountId}", oidcState.AccountId);
|
||||
return BadRequest("Invalid account ID format");
|
||||
}
|
||||
|
||||
var account = await _accountService.GetAccountByIdAsync(accountId);
|
||||
if (account == null)
|
||||
{
|
||||
_logger.LogError("Account not found for ID {AccountId}", accountId);
|
||||
return Unauthorized();
|
||||
}
|
||||
|
||||
// Add the OIDC connection to the account
|
||||
var connection = await _connectionService.AddConnectionAsync(account, userInfo, oidcState.Provider!);
|
||||
if (connection == null)
|
||||
{
|
||||
_logger.LogError("Failed to add OIDC connection for account {AccountId}", account.Id);
|
||||
return StatusCode(StatusCodes.Status500InternalServerError, "Failed to add OIDC connection");
|
||||
}
|
||||
|
||||
// Return success
|
||||
return Ok(new { Success = true });
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex, "Error handling OIDC account connection for user {UserId}", userInfo.UserId);
|
||||
return StatusCode(StatusCodes.Status500InternalServerError, "An error occurred while connecting your account");
|
||||
}
|
||||
}
|
||||
|
||||
private OidcService GetOidcService(string provider)
|
||||
{
|
||||
return provider.ToLower() switch
|
||||
{
|
||||
"apple" => serviceProvider.GetRequiredService<AppleOidcService>(),
|
||||
"google" => serviceProvider.GetRequiredService<GoogleOidcService>(),
|
||||
"microsoft" => serviceProvider.GetRequiredService<MicrosoftOidcService>(),
|
||||
"discord" => serviceProvider.GetRequiredService<DiscordOidcService>(),
|
||||
"github" => serviceProvider.GetRequiredService<GitHubOidcService>(),
|
||||
"afdian" => serviceProvider.GetRequiredService<AfdianOidcService>(),
|
||||
"apple" => _serviceProvider.GetRequiredService<AppleOidcService>(),
|
||||
"google" => _serviceProvider.GetRequiredService<GoogleOidcService>(),
|
||||
"microsoft" => _serviceProvider.GetRequiredService<MicrosoftOidcService>(),
|
||||
"discord" => _serviceProvider.GetRequiredService<DiscordOidcService>(),
|
||||
"github" => _serviceProvider.GetRequiredService<GitHubOidcService>(),
|
||||
"afdian" => _serviceProvider.GetRequiredService<AfdianOidcService>(),
|
||||
_ => throw new ArgumentException($"Unsupported provider: {provider}")
|
||||
};
|
||||
}
|
||||
|
||||
private async Task<Account> FindOrCreateAccount(OidcUserInfo userInfo, string provider)
|
||||
private async Task<CommonAccount> FindOrCreateAccount(CommonOidcUserInfo userInfo, string provider)
|
||||
{
|
||||
if (string.IsNullOrEmpty(userInfo.Email))
|
||||
throw new ArgumentException("Email is required for account creation");
|
||||
|
||||
// Check if an account exists by email
|
||||
var existingAccount = await accounts.LookupAccount(userInfo.Email);
|
||||
if (existingAccount != null)
|
||||
// Find or create the account connection
|
||||
var connection = await _connectionService.FindOrCreateConnection(userInfo, provider);
|
||||
|
||||
// If connection already has an account, return it
|
||||
if (!string.IsNullOrEmpty(connection.AccountId))
|
||||
{
|
||||
// Check if this provider connection already exists
|
||||
var existingConnection = await passDb.AccountConnections
|
||||
.FirstOrDefaultAsync(c => c.Provider == provider &&
|
||||
c.ProvidedIdentifier == userInfo.UserId &&
|
||||
c.AccountId == existingAccount.Id
|
||||
);
|
||||
|
||||
// If no connection exists, create one
|
||||
if (existingConnection != null)
|
||||
if (Guid.TryParse(connection.AccountId, out var accountId))
|
||||
{
|
||||
await passDb.AccountConnections
|
||||
.Where(c => c.AccountId == existingAccount.Id &&
|
||||
c.Provider == provider &&
|
||||
c.ProvidedIdentifier == userInfo.UserId)
|
||||
.ExecuteUpdateAsync(s => s
|
||||
.SetProperty(c => c.LastUsedAt, SystemClock.Instance.GetCurrentInstant())
|
||||
.SetProperty(c => c.Meta, userInfo.ToMetadata()));
|
||||
|
||||
return existingAccount;
|
||||
var existingAccount = await _accountService.GetAccountByIdAsync(accountId);
|
||||
if (existingAccount != null)
|
||||
{
|
||||
await _connectionService.UpdateConnection(connection, userInfo);
|
||||
return existingAccount;
|
||||
}
|
||||
}
|
||||
|
||||
var connection = new AccountConnection
|
||||
{
|
||||
AccountId = existingAccount.Id,
|
||||
Provider = provider,
|
||||
ProvidedIdentifier = userInfo.UserId!,
|
||||
AccessToken = userInfo.AccessToken,
|
||||
RefreshToken = userInfo.RefreshToken,
|
||||
LastUsedAt = SystemClock.Instance.GetCurrentInstant(),
|
||||
Meta = userInfo.ToMetadata()
|
||||
};
|
||||
|
||||
await passDb.AccountConnections.AddAsync(connection);
|
||||
await passDb.SaveChangesAsync();
|
||||
|
||||
return existingAccount;
|
||||
}
|
||||
|
||||
// Create new account using the AccountService
|
||||
var newAccount = await accounts.CreateAccount(userInfo);
|
||||
|
||||
// Create the provider connection
|
||||
var newConnection = new AccountConnection
|
||||
// Check if account exists by email
|
||||
var account = await _accountService.FindByEmailAsync(userInfo.Email);
|
||||
if (account == null)
|
||||
{
|
||||
AccountId = newAccount.Id,
|
||||
Provider = provider,
|
||||
ProvidedIdentifier = userInfo.UserId!,
|
||||
AccessToken = userInfo.AccessToken,
|
||||
RefreshToken = userInfo.RefreshToken,
|
||||
LastUsedAt = SystemClock.Instance.GetCurrentInstant(),
|
||||
Meta = userInfo.ToMetadata()
|
||||
};
|
||||
// Create new account using the account service
|
||||
account = new CommonAccount
|
||||
{
|
||||
Id = Guid.NewGuid().ToString(),
|
||||
Email = userInfo.Email,
|
||||
Name = userInfo.Name ?? userInfo.Email,
|
||||
CreatedAt = SystemClock.Instance.GetCurrentInstant()
|
||||
};
|
||||
|
||||
// Save the new account
|
||||
account = await _accountService.CreateAccountAsync(account);
|
||||
}
|
||||
|
||||
await passDb.AccountConnections.Add(newConnection);
|
||||
await passDb.SaveChangesAsync();
|
||||
// Update connection with account ID if needed
|
||||
if (string.IsNullOrEmpty(connection.AccountId))
|
||||
{
|
||||
connection.AccountId = account.Id;
|
||||
await _connectionService.UpdateConnection(connection, userInfo);
|
||||
}
|
||||
|
||||
return newAccount;
|
||||
return account;
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user