297 lines
12 KiB
C#
297 lines
12 KiB
C#
using System;
|
|
using System.IdentityModel.Tokens.Jwt;
|
|
using System.Security.Claims;
|
|
using System.Threading.Tasks;
|
|
using Microsoft.AspNetCore.Authentication;
|
|
using Microsoft.AspNetCore.Authorization;
|
|
using Microsoft.AspNetCore.Http;
|
|
using Microsoft.AspNetCore.Mvc;
|
|
using Microsoft.Extensions.DependencyInjection;
|
|
using Microsoft.Extensions.Logging;
|
|
using NodaTime;
|
|
using DysonNetwork.Common.Models;
|
|
using DysonNetwork.Pass.Data;
|
|
using DysonNetwork.Pass.Features.Auth.Models;
|
|
using DysonNetwork.Pass.Features.Auth.Services;
|
|
using Microsoft.IdentityModel.Tokens;
|
|
|
|
// Use fully qualified names to avoid ambiguity
|
|
using CommonAccount = DysonNetwork.Common.Models.Account;
|
|
using CommonOidcUserInfo = DysonNetwork.Common.Models.OidcUserInfo;
|
|
|
|
namespace DysonNetwork.Pass.Features.Auth.OpenId;
|
|
|
|
/// <summary>
/// OpenID Connect entry points: browser redirects to an external provider (either a login /
/// registration flow, or linking a provider to the already-signed-in account) and native
/// Apple Sign In for mobile apps.
/// </summary>
[ApiController]
[Route("/auth/login")]
public class OidcController : ControllerBase
{
    /// <summary>Cache-key prefix under which pending <see cref="OidcState"/> entries are stored.</summary>
    private const string StateCachePrefix = "oidc-state:";

    /// <summary>How long a pending authorization request stays in the cache.</summary>
    private static readonly TimeSpan StateExpiration = TimeSpan.FromMinutes(15);

    private readonly ILogger<OidcController> _logger;
    private readonly IServiceProvider _serviceProvider;
    private readonly PassDatabase _db;
    private readonly IAccountService _accountService;
    private readonly IAccountConnectionService _connectionService;
    private readonly ICacheService _cache;

    /// <summary>Creates the controller; every dependency is required.</summary>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public OidcController(
        IServiceProvider serviceProvider,
        PassDatabase db,
        IAccountService accountService,
        IAccountConnectionService connectionService,
        ICacheService cache,
        ILogger<OidcController> logger)
    {
        _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider));
        _db = db ?? throw new ArgumentNullException(nameof(db));
        _accountService = accountService ?? throw new ArgumentNullException(nameof(accountService));
        _connectionService = connectionService ?? throw new ArgumentNullException(nameof(connectionService));
        _cache = cache ?? throw new ArgumentNullException(nameof(cache));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    }

    /// <summary>
    /// Starts the OpenID Connect flow for <paramref name="provider"/> by redirecting to the
    /// provider's authorization URL.
    /// </summary>
    /// <remarks>
    /// If the request already authenticates, the flow is a request to connect the provider to
    /// the caller's account. Otherwise it is a login / registration flow. Either way a fresh
    /// state and nonce are generated, and an <see cref="OidcState"/> is cached under
    /// <c>oidc-state:{state}</c> for <see cref="StateExpiration"/>.
    /// </remarks>
    /// <param name="provider">Provider key, matched case-insensitively (e.g. <c>google</c>, <c>github</c>).</param>
    /// <param name="returnUrl">Where the client should go after login; defaults to <c>/</c>.</param>
    /// <param name="deviceId">Optional device identifier from the <c>X-Device-Id</c> header.</param>
    /// <returns>
    /// One of:
    /// <list type="bullet">
    /// <item>a redirect to the provider;</item>
    /// <item>400 for an unsupported provider;</item>
    /// <item>401 when an authenticated principal lacks a <see cref="ClaimTypes.NameIdentifier"/> claim;</item>
    /// <item>500 for unexpected failures.</item>
    /// </list>
    /// </returns>
    [HttpGet("{provider}")]
    public async Task<ActionResult> OidcLogin(
        [FromRoute] string provider,
        [FromQuery] string? returnUrl = "/",
        [FromHeader(Name = "X-Device-Id")] string? deviceId = null)
    {
        try
        {
            // Throws ArgumentException for an unknown provider key (handled below as 400).
            var oidcService = GetOidcService(provider);

            // The state value sent to the provider is also the cache key for the pending request.
            var state = Guid.NewGuid().ToString();
            var nonce = Guid.NewGuid().ToString();

            OidcState oidcState;

            // If the user is already authenticated, treat this as an account connection request.
            var currentUser = await HttpContext.AuthenticateAsync();
            if (currentUser.Succeeded && currentUser.Principal?.Identity?.IsAuthenticated == true)
            {
                var accountId = currentUser.Principal.FindFirstValue(ClaimTypes.NameIdentifier);
                if (string.IsNullOrEmpty(accountId))
                {
                    _logger.LogWarning("Authenticated user does not have a valid account ID");
                    return Unauthorized();
                }

                oidcState = OidcState.ForConnection(accountId, provider, nonce, deviceId);
            }
            else
            {
                // Login / registration flow.
                // NOTE(review): returnUrl is caller-supplied and is echoed back after login. If a client
                // redirects to it, validate it (e.g. Url.IsLocalUrl) to avoid an open redirect.
                oidcState = OidcState.ForLogin(returnUrl ?? "/", deviceId);
                oidcState.Provider = provider;
                oidcState.Nonce = nonce;
            }

            await _cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration);

            var authUrl = oidcService.GetAuthorizationUrl(state, nonce);
            return Redirect(authUrl);
        }
        catch (ArgumentException ex)
        {
            // Client error, e.g. an unsupported provider key.
            _logger.LogWarning(ex, "Rejected OpenID Connect request for provider {Provider}", provider);
            return BadRequest($"Error initiating OpenID Connect flow: {ex.Message}");
        }
        catch (Exception ex)
        {
            // Server-side fault (cache, DI, provider service). Log the details instead of echoing
            // internal exception text to the client.
            _logger.LogError(ex, "Error initiating OpenID Connect flow for provider {Provider}", provider);
            return StatusCode(StatusCodes.Status500InternalServerError, "Error initiating OpenID Connect flow");
        }
    }

    /// <summary>
    /// Mobile Apple Sign In endpoint.
    /// Handles Apple authentication directly from mobile apps.
    /// </summary>
    /// <remarks>
    /// The request's identity token and authorization code are passed to
    /// <see cref="AppleOidcService"/>. The account is then found or created, and an
    /// authentication challenge is created for it.
    /// </remarks>
    /// <param name="request">Identity token, authorization code and device ID sent by the app.</param>
    /// <returns>
    /// One of:
    /// <list type="bullet">
    /// <item>200 with the challenge;</item>
    /// <item>401 for an invalid identity token;</item>
    /// <item>503 when no Apple service is registered;</item>
    /// <item>500 otherwise.</item>
    /// </list>
    /// </returns>
    [HttpPost("apple/mobile")]
    public async Task<ActionResult<Models.AuthChallenge>> AppleMobileSignIn(
        [FromBody] AppleMobileSignInRequest request)
    {
        try
        {
            // Resolve optionally: GetRequiredService would throw instead of returning null,
            // which made the 503 branch unreachable.
            if (_serviceProvider.GetService<AppleOidcService>() is not AppleOidcService appleService)
                return StatusCode(StatusCodes.Status503ServiceUnavailable, "Apple OIDC service not available");

            var callbackData = new OidcCallbackData
            {
                IdToken = request.IdentityToken,
                Code = request.AuthorizationCode,
            };

            var userInfo = await appleService.ProcessCallbackAsync(callbackData);

            var account = await FindOrCreateAccount(userInfo, "apple");

            var challenge = await appleService.CreateChallengeForUserAsync(
                userInfo,
                account,
                HttpContext,
                request.DeviceId
            );

            if (challenge == null)
            {
                _logger.LogError("Apple mobile sign-in could not create a challenge for account {AccountId}", account.Id);
                return StatusCode(StatusCodes.Status500InternalServerError, "Failed to create authentication challenge");
            }

            return Ok(challenge);
        }
        catch (SecurityTokenValidationException ex)
        {
            _logger.LogWarning(ex, "Apple mobile sign-in rejected an invalid identity token");
            return Unauthorized($"Invalid identity token: {ex.Message}");
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Apple mobile sign-in failed");
            return StatusCode(StatusCodes.Status500InternalServerError, $"Authentication failed: {ex.Message}");
        }
    }

    /// <summary>
    /// Completes a login flow: finds or creates the account for <paramref name="userInfo"/>,
    /// opens a session and returns auth tokens plus the stored return URL as JSON.
    /// </summary>
    /// <param name="oidcState">Cached state of the pending request; its Provider must be set.</param>
    /// <param name="userInfo">User information obtained from the provider.</param>
    /// <returns>200 with tokens and ReturnUrl, or 500 on any failure (logged).</returns>
    private async Task<IActionResult> HandleLogin(OidcState oidcState, OidcUserInfo userInfo)
    {
        try
        {
            var account = await _accountService.FindOrCreateAccountAsync(
                userInfo,
                oidcState.Provider ?? throw new InvalidOperationException("Provider not specified"));
            if (account == null)
            {
                _logger.LogError("Failed to find or create account for user {UserId}", userInfo.UserId);
                return StatusCode(StatusCodes.Status500InternalServerError, "Failed to process your account");
            }

            var session = await _connectionService.CreateSessionAsync(account, oidcState.DeviceId);
            if (session == null)
            {
                _logger.LogError("Failed to create session for account {AccountId}", account.Id);
                return StatusCode(StatusCodes.Status500InternalServerError, "Failed to create session");
            }

            var tokens = await _accountService.GenerateAuthTokensAsync(account, session.Id.ToString());

            return Ok(new
            {
                tokens.AccessToken,
                tokens.RefreshToken,
                tokens.ExpiresIn,
                ReturnUrl = oidcState.ReturnUrl ?? "/"
            });
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error handling OIDC login for user {UserId}", userInfo.UserId);
            return StatusCode(StatusCodes.Status500InternalServerError, "An error occurred during login");
        }
    }

    /// <summary>
    /// Completes a connection flow by attaching the provider identity in
    /// <paramref name="userInfo"/> to the account recorded in <paramref name="oidcState"/>.
    /// </summary>
    /// <param name="oidcState">Cached state of the pending request; AccountId and Provider must be set.</param>
    /// <param name="userInfo">User information obtained from the provider.</param>
    /// <returns>
    /// One of:
    /// <list type="bullet">
    /// <item>200 on success;</item>
    /// <item>400 for a malformed account ID;</item>
    /// <item>401 for an unknown account;</item>
    /// <item>500 on other failures.</item>
    /// </list>
    /// </returns>
    private async Task<IActionResult> HandleAccountConnection(OidcState oidcState, OidcUserInfo userInfo)
    {
        try
        {
            if (!Guid.TryParse(oidcState.AccountId, out var accountId))
            {
                _logger.LogError("Invalid account ID format: {AccountId}", oidcState.AccountId);
                return BadRequest("Invalid account ID format");
            }

            var account = await _accountService.GetAccountByIdAsync(accountId);
            if (account == null)
            {
                _logger.LogError("Account not found for ID {AccountId}", accountId);
                return Unauthorized();
            }

            // Guard explicitly (matching HandleLogin) rather than passing a null provider downstream.
            var provider = oidcState.Provider ?? throw new InvalidOperationException("Provider not specified");

            var connection = await _connectionService.AddConnectionAsync(account, userInfo, provider);
            if (connection == null)
            {
                _logger.LogError("Failed to add OIDC connection for account {AccountId}", account.Id);
                return StatusCode(StatusCodes.Status500InternalServerError, "Failed to add OIDC connection");
            }

            return Ok(new { Success = true });
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error handling OIDC account connection for user {UserId}", userInfo.UserId);
            return StatusCode(StatusCodes.Status500InternalServerError, "An error occurred while connecting your account");
        }
    }

    /// <summary>Resolves the OIDC service registered for a provider key.</summary>
    /// <param name="provider">Provider key, matched case-insensitively.</param>
    /// <exception cref="ArgumentException">The provider key is not supported.</exception>
    private OidcService GetOidcService(string provider)
    {
        // Culture-invariant lowering: culture-sensitive ToLower() maps "I" to dotless "ı" under
        // tr-TR, which would make e.g. "GITHUB" fail to match.
        return provider.ToLowerInvariant() switch
        {
            "apple" => _serviceProvider.GetRequiredService<AppleOidcService>(),
            "google" => _serviceProvider.GetRequiredService<GoogleOidcService>(),
            "microsoft" => _serviceProvider.GetRequiredService<MicrosoftOidcService>(),
            "discord" => _serviceProvider.GetRequiredService<DiscordOidcService>(),
            "github" => _serviceProvider.GetRequiredService<GitHubOidcService>(),
            "afdian" => _serviceProvider.GetRequiredService<AfdianOidcService>(),
            _ => throw new ArgumentException($"Unsupported provider: {provider}")
        };
    }

    /// <summary>
    /// Returns the account for a provider identity, creating it if needed.
    /// </summary>
    /// <remarks>
    /// Lookup order:
    /// <list type="number">
    /// <item>the account the connection is already linked to;</item>
    /// <item>an existing account with the same email;</item>
    /// <item>a newly created account.</item>
    /// </list>
    /// The connection is then linked to the result if it was unlinked or linked to an account
    /// that no longer resolves.
    /// </remarks>
    /// <param name="userInfo">Provider user info; Email is required.</param>
    /// <param name="provider">Provider key the connection belongs to.</param>
    /// <exception cref="ArgumentException"><paramref name="userInfo"/> has no email.</exception>
    private async Task<CommonAccount> FindOrCreateAccount(CommonOidcUserInfo userInfo, string provider)
    {
        if (string.IsNullOrEmpty(userInfo.Email))
            throw new ArgumentException("Email is required for account creation");

        var connection = await _connectionService.FindOrCreateConnection(userInfo, provider);

        // Fast path: the connection is already linked to an account that still exists.
        // (Guid.TryParse returns false for null/empty, covering the "not linked" case.)
        if (Guid.TryParse(connection.AccountId, out var linkedAccountId))
        {
            var existingAccount = await _accountService.GetAccountByIdAsync(linkedAccountId);
            if (existingAccount != null)
            {
                await _connectionService.UpdateConnection(connection, userInfo);
                return existingAccount;
            }
        }

        // NOTE(review): matching by email links this provider identity to any existing account
        // with the same address. Confirm the provider has verified the email, or this permits
        // account takeover.
        var account = await _accountService.FindByEmailAsync(userInfo.Email);
        if (account == null)
        {
            account = await _accountService.CreateAccountAsync(new CommonAccount
            {
                Id = Guid.NewGuid().ToString(),
                Email = userInfo.Email,
                Name = userInfo.Name ?? userInfo.Email,
                CreatedAt = SystemClock.Instance.GetCurrentInstant()
            });
        }

        // Reaching this point means the connection was either unlinked or pointed at an account
        // that did not resolve above. Re-link it in both cases so it no longer dangles.
        if (connection.AccountId != account.Id)
        {
            connection.AccountId = account.Id;
            await _connectionService.UpdateConnection(connection, userInfo);
        }

        return account;
    }
}