// Files
// Swarm/DysonNetwork.Pass/Features/Auth/Services/OpenId/OidcController.cs
//
// 297 lines
// 12 KiB
// C#

using System;
using System.IdentityModel.Tokens.Jwt;
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using NodaTime;
using DysonNetwork.Common.Models;
using DysonNetwork.Pass.Data;
using DysonNetwork.Pass.Features.Auth.Models;
using DysonNetwork.Pass.Features.Auth.Services;
using Microsoft.IdentityModel.Tokens;
// Use fully qualified names to avoid ambiguity
using CommonAccount = DysonNetwork.Common.Models.Account;
using CommonOidcUserInfo = DysonNetwork.Common.Models.OidcUserInfo;
namespace DysonNetwork.Pass.Features.Auth.OpenId;
/// <summary>
/// API controller routed under <c>/auth/login</c> that starts OpenID Connect flows for the
/// supported external providers and handles the direct Apple sign-in used by mobile apps.
/// </summary>
[ApiController]
[Route("/auth/login")]
public class OidcController : ControllerBase
{
// Prefix prepended to the generated state value to build the cache key for pending OIDC state.
private const string StateCachePrefix = "oidc-state:";
// Duration handed to the cache when a pending OIDC state entry is stored.
private static readonly TimeSpan StateExpiration = TimeSpan.FromMinutes(15);
private readonly ILogger<OidcController> _logger;
// Used to resolve the provider-specific OIDC service on demand (see GetOidcService).
private readonly IServiceProvider _serviceProvider;
// NOTE(review): not referenced by any member visible in this section of the file.
private readonly PassDatabase _db;
private readonly IAccountService _accountService;
private readonly IAccountConnectionService _connectionService;
// Holds the pending OIDC state objects keyed by StateCachePrefix + state.
private readonly ICacheService _cache;
/// <summary>
/// Builds the controller from its injected dependencies.
/// </summary>
/// <exception cref="ArgumentNullException">
/// Thrown for the first dependency (in parameter order) that is <c>null</c>.
/// </exception>
public OidcController(
    IServiceProvider serviceProvider,
    PassDatabase db,
    IAccountService accountService,
    IAccountConnectionService connectionService,
    ICacheService cache,
    ILogger<OidcController> logger)
{
    // Validate everything up front, in parameter order, before touching any field.
    ArgumentNullException.ThrowIfNull(serviceProvider);
    ArgumentNullException.ThrowIfNull(db);
    ArgumentNullException.ThrowIfNull(accountService);
    ArgumentNullException.ThrowIfNull(connectionService);
    ArgumentNullException.ThrowIfNull(cache);
    ArgumentNullException.ThrowIfNull(logger);

    _serviceProvider = serviceProvider;
    _db = db;
    _accountService = accountService;
    _connectionService = connectionService;
    _cache = cache;
    _logger = logger;
}
/// <summary>
/// Starts an OpenID Connect flow with the given provider and redirects the caller to the
/// provider's authorization URL.
/// </summary>
/// <remarks>
/// When the request is already authenticated the flow is recorded as an account-connection
/// request for the current account; otherwise it is recorded as a login / registration flow.
/// In both cases the pending state is cached under <c>StateCachePrefix + state</c> for
/// <c>StateExpiration</c>.
/// </remarks>
/// <param name="provider">Provider name taken from the route (e.g. <c>google</c>).</param>
/// <param name="returnUrl">URL stored with the login state; defaults to <c>/</c>.</param>
/// <param name="deviceId">Optional device identifier from the <c>X-Device-Id</c> header.</param>
/// <returns>
/// A redirect to the provider; 401 when the authenticated principal carries no account id;
/// 400 when the provider is not supported; 500 for any other failure.
/// </returns>
[HttpGet("{provider}")]
public async Task<ActionResult> OidcLogin(
    [FromRoute] string provider,
    [FromQuery] string? returnUrl = "/",
    [FromHeader(Name = "X-Device-Id")] string? deviceId = null)
{
    try
    {
        var oidcService = GetOidcService(provider);
        var currentUser = await HttpContext.AuthenticateAsync();

        // The state value is both sent to the provider and used as the cache key suffix.
        var state = Guid.NewGuid().ToString();
        var nonce = Guid.NewGuid().ToString();

        OidcState oidcState;
        if (currentUser.Succeeded && currentUser.Principal?.Identity?.IsAuthenticated == true)
        {
            // Already signed in: treat this as a request to connect the provider to the account.
            var accountId = currentUser.Principal.FindFirstValue(ClaimTypes.NameIdentifier);
            if (string.IsNullOrEmpty(accountId))
            {
                _logger.LogWarning("Authenticated user does not have a valid account ID");
                return Unauthorized();
            }

            oidcState = OidcState.ForConnection(accountId, provider, nonce, deviceId);
        }
        else
        {
            // Anonymous caller: plain login / registration flow.
            // NOTE(review): returnUrl is caller-supplied and stored as-is; confirm whoever
            // consumes it validates it before redirecting.
            oidcState = OidcState.ForLogin(returnUrl ?? "/", deviceId);
            oidcState.Provider = provider;
            oidcState.Nonce = nonce;
        }

        await _cache.SetAsync($"{StateCachePrefix}{state}", oidcState, StateExpiration);
        return Redirect(oidcService.GetAuthorizationUrl(state, nonce));
    }
    catch (ArgumentException ex)
    {
        // Raised by GetOidcService for an unsupported provider: a client error.
        _logger.LogWarning(ex, "Rejected OpenID Connect flow for provider {Provider}", provider);
        return BadRequest($"Error initiating OpenID Connect flow: {ex.Message}");
    }
    catch (Exception ex)
    {
        // Server-side failure: log the details, but do not echo them back to the caller.
        _logger.LogError(ex, "Error initiating OpenID Connect flow for provider {Provider}", provider);
        return StatusCode(
            StatusCodes.Status500InternalServerError,
            "An error occurred while initiating the OpenID Connect flow");
    }
}
/// <summary>
/// Mobile Apple Sign In endpoint.
/// Handles Apple authentication directly from mobile apps: the identity token and
/// authorization code from the request are processed by the Apple OIDC service, the matching
/// account is found or created, and an authentication challenge is returned.
/// </summary>
/// <param name="request">Request body carrying the identity token, authorization code and device id.</param>
/// <returns>
/// 200 with the challenge on success; 401 when the identity token fails validation;
/// 503 when the Apple service is not available; 500 for any other failure.
/// </returns>
[HttpPost("apple/mobile")]
public async Task<ActionResult<Models.AuthChallenge>> AppleMobileSignIn(
    [FromBody] AppleMobileSignInRequest request)
{
    try
    {
        if (GetOidcService("apple") is not AppleOidcService appleService)
        {
            return StatusCode(StatusCodes.Status503ServiceUnavailable, "Apple OIDC service not available");
        }

        // Hand the token and code to the Apple service in the shape it expects for callbacks.
        var callbackData = new OidcCallbackData
        {
            IdToken = request.IdentityToken,
            Code = request.AuthorizationCode,
        };

        var userInfo = await appleService.ProcessCallbackAsync(callbackData);
        var account = await FindOrCreateAccount(userInfo, "apple");

        var challenge = await appleService.CreateChallengeForUserAsync(
            userInfo,
            account,
            HttpContext,
            request.DeviceId
        );
        if (challenge == null)
        {
            _logger.LogError("Apple mobile sign-in produced no authentication challenge");
            return StatusCode(StatusCodes.Status500InternalServerError, "Failed to create authentication challenge");
        }

        return Ok(challenge);
    }
    catch (SecurityTokenValidationException ex)
    {
        _logger.LogWarning(ex, "Apple mobile sign-in rejected: identity token failed validation");
        return Unauthorized($"Invalid identity token: {ex.Message}");
    }
    catch (Exception ex)
    {
        // Log the full exception server-side; the response stays generic so internal
        // details are not exposed to the client.
        _logger.LogError(ex, "Apple mobile sign-in failed");
        return StatusCode(StatusCodes.Status500InternalServerError, "Authentication failed");
    }
}
/// <summary>
/// Completes a login flow: resolves the account for the external user, opens a session for
/// it and returns the generated tokens together with the stored return URL.
/// </summary>
/// <param name="oidcState">Pending state recorded when the flow was started.</param>
/// <param name="userInfo">User information obtained from the provider.</param>
/// <returns>
/// 200 with the tokens and return URL, or 500 with a short message when any step fails.
/// </returns>
private async Task<IActionResult> HandleLogin(OidcState oidcState, OidcUserInfo userInfo)
{
    // Every failure path answers with the same status code; only the message differs.
    ObjectResult ServerError(string message) =>
        StatusCode(StatusCodes.Status500InternalServerError, message);

    try
    {
        var providerName = oidcState.Provider
            ?? throw new InvalidOperationException("Provider not specified");

        var resolvedAccount = await _accountService.FindOrCreateAccountAsync(userInfo, providerName);
        if (resolvedAccount == null)
        {
            _logger.LogError("Failed to find or create account for user {UserId}", userInfo.UserId);
            return ServerError("Failed to process your account");
        }

        var newSession = await _connectionService.CreateSessionAsync(resolvedAccount, oidcState.DeviceId);
        if (newSession == null)
        {
            _logger.LogError("Failed to create session for account {AccountId}", resolvedAccount.Id);
            return ServerError("Failed to create session");
        }

        var sessionKey = newSession.Id.ToString();
        var issued = await _accountService.GenerateAuthTokensAsync(resolvedAccount, sessionKey);

        var payload = new
        {
            AccessToken = issued.AccessToken,
            RefreshToken = issued.RefreshToken,
            ExpiresIn = issued.ExpiresIn,
            ReturnUrl = oidcState.ReturnUrl ?? "/"
        };
        return Ok(payload);
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "Error handling OIDC login for user {UserId}", userInfo.UserId);
        return ServerError("An error occurred during login");
    }
}
/// <summary>
/// Completes an account-connection flow by attaching the external identity described by
/// <paramref name="userInfo"/> to the account recorded in <paramref name="oidcState"/>.
/// </summary>
/// <param name="oidcState">Pending state carrying the account id and provider.</param>
/// <param name="userInfo">User information obtained from the provider.</param>
/// <returns>
/// 200 with <c>Success = true</c>; 400 when the stored account id is not a GUID; 401 when the
/// account no longer exists; 500 when the connection cannot be added or another error occurs.
/// </returns>
private async Task<IActionResult> HandleAccountConnection(OidcState oidcState, OidcUserInfo userInfo)
{
    try
    {
        if (!Guid.TryParse(oidcState.AccountId, out var accountId))
        {
            _logger.LogError("Invalid account ID format: {AccountId}", oidcState.AccountId);
            return BadRequest("Invalid account ID format");
        }

        var account = await _accountService.GetAccountByIdAsync(accountId);
        if (account == null)
        {
            _logger.LogError("Account not found for ID {AccountId}", accountId);
            return Unauthorized();
        }

        // Fail loudly on a missing provider (same guard as HandleLogin) instead of passing
        // null on to the connection service; the catch below logs it and answers with a 500.
        var provider = oidcState.Provider
            ?? throw new InvalidOperationException("Provider not specified");

        var connection = await _connectionService.AddConnectionAsync(account, userInfo, provider);
        if (connection == null)
        {
            _logger.LogError("Failed to add OIDC connection for account {AccountId}", account.Id);
            return StatusCode(StatusCodes.Status500InternalServerError, "Failed to add OIDC connection");
        }

        return Ok(new { Success = true });
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "Error handling OIDC account connection for user {UserId}", userInfo.UserId);
        return StatusCode(StatusCodes.Status500InternalServerError, "An error occurred while connecting your account");
    }
}
/// <summary>
/// Resolves the OIDC service registered for the given provider name (case-insensitive).
/// </summary>
/// <param name="provider">Provider name, e.g. <c>apple</c>, <c>google</c>, <c>github</c>.</param>
/// <returns>The provider-specific service resolved from the service provider.</returns>
/// <exception cref="ArgumentNullException"><paramref name="provider"/> is <c>null</c>.</exception>
/// <exception cref="ArgumentException">The provider name is not one of the supported values.</exception>
private OidcService GetOidcService(string provider)
{
    ArgumentNullException.ThrowIfNull(provider);

    // Lower-case with the invariant culture: a culture-sensitive ToLower() maps 'I' to a
    // dotless 'ı' under e.g. tr-TR, which would make "GITHUB" or "MICROSOFT" unmatched.
    return provider.ToLowerInvariant() switch
    {
        "apple" => _serviceProvider.GetRequiredService<AppleOidcService>(),
        "google" => _serviceProvider.GetRequiredService<GoogleOidcService>(),
        "microsoft" => _serviceProvider.GetRequiredService<MicrosoftOidcService>(),
        "discord" => _serviceProvider.GetRequiredService<DiscordOidcService>(),
        "github" => _serviceProvider.GetRequiredService<GitHubOidcService>(),
        "afdian" => _serviceProvider.GetRequiredService<AfdianOidcService>(),
        _ => throw new ArgumentException($"Unsupported provider: {provider}")
    };
}
/// <summary>
/// Resolves the account behind an external identity: uses the account already linked to the
/// provider connection when it still exists, otherwise falls back to the account with the
/// same email, creating a new account when none is found. The connection is left pointing at
/// the returned account.
/// </summary>
/// <param name="userInfo">User information obtained from the provider; must carry an email.</param>
/// <param name="provider">Provider name the connection belongs to.</param>
/// <returns>The existing or newly created account.</returns>
/// <exception cref="ArgumentException">The user information has no email.</exception>
private async Task<CommonAccount> FindOrCreateAccount(CommonOidcUserInfo userInfo, string provider)
{
    if (string.IsNullOrEmpty(userInfo.Email))
    {
        throw new ArgumentException("Email is required for account creation");
    }

    var connection = await _connectionService.FindOrCreateConnection(userInfo, provider);

    // Fast path: the connection is already linked to an account that still exists.
    if (!string.IsNullOrEmpty(connection.AccountId)
        && Guid.TryParse(connection.AccountId, out var linkedAccountId))
    {
        var linkedAccount = await _accountService.GetAccountByIdAsync(linkedAccountId);
        if (linkedAccount != null)
        {
            await _connectionService.UpdateConnection(connection, userInfo);
            return linkedAccount;
        }
    }

    // Reaching this point means the connection has no usable link: the account id is empty,
    // does not parse, or refers to an account that can no longer be found.
    // NOTE(review): matching by email assumes the provider has verified the address — confirm
    // for every provider routed through here.
    var account = await _accountService.FindByEmailAsync(userInfo.Email);
    if (account == null)
    {
        account = new CommonAccount
        {
            Id = Guid.NewGuid().ToString(),
            Email = userInfo.Email,
            Name = userInfo.Name ?? userInfo.Email,
            CreatedAt = SystemClock.Instance.GetCurrentInstant()
        };
        account = await _accountService.CreateAccountAsync(account);
    }

    // Always (re)link here. Previously only an empty account id was linked, so a connection
    // holding a stale or malformed id stayed dangling and was never updated on later sign-ins.
    connection.AccountId = account.Id;
    await _connectionService.UpdateConnection(connection, userInfo);

    return account;
}
}