379 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			C#
		
	
	
	
	
	
			
		
		
	
	
			379 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			C#
		
	
	
	
	
	
using DysonNetwork.Pass.Account;
 | 
						|
using Microsoft.AspNetCore.Authorization;
 | 
						|
using Microsoft.AspNetCore.Mvc;
 | 
						|
using Microsoft.EntityFrameworkCore;
 | 
						|
using DysonNetwork.Shared.Cache;
 | 
						|
using NodaTime;
 | 
						|
using DysonNetwork.Shared.Models;
 | 
						|
 | 
						|
namespace DysonNetwork.Pass.Auth.OpenId;
 | 
						|
 | 
						|
[ApiController]
 | 
						|
[Route("/api/accounts/me/connections")]
 | 
						|
[Authorize]
 | 
						|
public class ConnectionController(
 | 
						|
    AppDatabase db,
 | 
						|
    IEnumerable<OidcService> oidcServices,
 | 
						|
    AccountService accounts,
 | 
						|
    AuthService auth,
 | 
						|
    ICacheService cache,
 | 
						|
    IConfiguration configuration
 | 
						|
) : ControllerBase
 | 
						|
{
 | 
						|
    private const string StateCachePrefix = "oidc-state:";
 | 
						|
    private const string ReturnUrlCachePrefix = "oidc-returning:";
 | 
						|
    private static readonly TimeSpan StateExpiration = TimeSpan.FromMinutes(15);
 | 
						|
 | 
						|
    [HttpGet]
 | 
						|
    public async Task<ActionResult<List<SnAccountConnection>>> GetConnections()
 | 
						|
    {
 | 
						|
        if (HttpContext.Items["CurrentUser"] is not SnAccount currentUser)
 | 
						|
            return Unauthorized();
 | 
						|
 | 
						|
        var connections = await db.AccountConnections
 | 
						|
            .Where(c => c.AccountId == currentUser.Id)
 | 
						|
            .Select(c => new
 | 
						|
            {
 | 
						|
                c.Id,
 | 
						|
                c.AccountId,
 | 
						|
                c.Provider,
 | 
						|
                c.ProvidedIdentifier,
 | 
						|
                c.Meta,
 | 
						|
                c.LastUsedAt,
 | 
						|
                c.CreatedAt,
 | 
						|
                c.UpdatedAt,
 | 
						|
            })
 | 
						|
            .ToListAsync();
 | 
						|
        return Ok(connections);
 | 
						|
    }
 | 
						|
 | 
						|
    [HttpDelete("{id:guid}")]
 | 
						|
    public async Task<ActionResult> RemoveConnection(Guid id)
 | 
						|
    {
 | 
						|
        if (HttpContext.Items["CurrentUser"] is not SnAccount currentUser)
 | 
						|
            return Unauthorized();
 | 
						|
 | 
						|
        var connection = await db.AccountConnections
 | 
						|
            .Where(c => c.Id == id && c.AccountId == currentUser.Id)
 | 
						|
            .FirstOrDefaultAsync();
 | 
						|
        if (connection == null)
 | 
						|
            return NotFound();
 | 
						|
 | 
						|
        db.AccountConnections.Remove(connection);
 | 
						|
        await db.SaveChangesAsync();
 | 
						|
 | 
						|
        return Ok();
 | 
						|
    }
 | 
						|
 | 
						|
    [HttpPost("/api/auth/connect/apple/mobile")]
 | 
						|
    public async Task<ActionResult> ConnectAppleMobile([FromBody] AppleMobileConnectRequest request)
 | 
						|
    {
 | 
						|
        if (HttpContext.Items["CurrentUser"] is not SnAccount currentUser)
 | 
						|
            return Unauthorized();
 | 
						|
 | 
						|
        if (GetOidcService("apple") is not AppleOidcService appleService)
 | 
						|
            return StatusCode(503, "Apple OIDC service not available");
 | 
						|
 | 
						|
        var callbackData = new OidcCallbackData
 | 
						|
        {
 | 
						|
            IdToken = request.IdentityToken,
 | 
						|
            Code = request.AuthorizationCode,
 | 
						|
        };
 | 
						|
 | 
						|
        OidcUserInfo userInfo;
 | 
						|
        try
 | 
						|
        {
 | 
						|
            userInfo = await appleService.ProcessCallbackAsync(callbackData);
 | 
						|
        }
 | 
						|
        catch (Exception ex)
 | 
						|
        {
 | 
						|
            return BadRequest($"Error processing Apple token: {ex.Message}");
 | 
						|
        }
 | 
						|
 | 
						|
        var existingConnection = await db.AccountConnections
 | 
						|
            .FirstOrDefaultAsync(c =>
 | 
						|
                c.Provider == "apple" &&
 | 
						|
                c.ProvidedIdentifier == userInfo.UserId);
 | 
						|
 | 
						|
        if (existingConnection != null)
 | 
						|
        {
 | 
						|
            return BadRequest(
 | 
						|
                $"This Apple account is already linked to {(existingConnection.AccountId == currentUser.Id ? "your account" : "another user")}.");
 | 
						|
        }
 | 
						|
 | 
						|
        db.AccountConnections.Add(new SnAccountConnection
 | 
						|
        {
 | 
						|
            AccountId = currentUser.Id,
 | 
						|
            Provider = "apple",
 | 
						|
            ProvidedIdentifier = userInfo.UserId!,
 | 
						|
            AccessToken = userInfo.AccessToken,
 | 
						|
            RefreshToken = userInfo.RefreshToken,
 | 
						|
            LastUsedAt = SystemClock.Instance.GetCurrentInstant(),
 | 
						|
            Meta = userInfo.ToMetadata(),
 | 
						|
        });
 | 
						|
 | 
						|
        await db.SaveChangesAsync();
 | 
						|
 | 
						|
        return Ok(new { message = "Successfully connected Apple account." });
 | 
						|
    }
 | 
						|
 | 
						|
    private OidcService? GetOidcService(string provider)
 | 
						|
    {
 | 
						|
        return oidcServices.FirstOrDefault(s => s.ProviderName.Equals(provider, StringComparison.OrdinalIgnoreCase));
 | 
						|
    }
 | 
						|
 | 
						|
    public class ConnectProviderRequest
 | 
						|
    {
 | 
						|
        public string Provider { get; set; } = null!;
 | 
						|
        public string? ReturnUrl { get; set; }
 | 
						|
    }
 | 
						|
 | 
						|
    [AllowAnonymous]
 | 
						|
    [Route("/api/auth/callback/{provider}")]
 | 
						|
    [HttpGet, HttpPost]
 | 
						|
    public async Task<IActionResult> HandleCallback([FromRoute] string provider)
 | 
						|
    {
 | 
						|
        var oidcService = GetOidcService(provider);
 | 
						|
        if (oidcService == null)
 | 
						|
            return BadRequest($"Provider '{provider}' is not supported.");
 | 
						|
 | 
						|
        var callbackData = await ExtractCallbackData(Request);
 | 
						|
        if (callbackData.State == null)
 | 
						|
            return BadRequest("State parameter is missing.");
 | 
						|
 | 
						|
        // Get the state from the cache
 | 
						|
        var stateKey = $"{StateCachePrefix}{callbackData.State}";
 | 
						|
 | 
						|
        // Try to get the state as OidcState first (new format)
 | 
						|
        var oidcState = await cache.GetAsync<OidcState>(stateKey);
 | 
						|
 | 
						|
        // If not found, try to get as string (legacy format)
 | 
						|
        if (oidcState == null)
 | 
						|
        {
 | 
						|
            var stateValue = await cache.GetAsync<string>(stateKey);
 | 
						|
            if (string.IsNullOrEmpty(stateValue) || !OidcState.TryParse(stateValue, out oidcState) || oidcState == null)
 | 
						|
                return BadRequest("Invalid or expired state parameter");
 | 
						|
        }
 | 
						|
 | 
						|
        // Remove the state from cache to prevent replay attacks
 | 
						|
        await cache.RemoveAsync(stateKey);
 | 
						|
 | 
						|
        // Handle the flow based on state type
 | 
						|
        if (oidcState is { FlowType: OidcFlowType.Connect, AccountId: not null })
 | 
						|
        {
 | 
						|
            // Connection flow
 | 
						|
            if (oidcState.DeviceId != null)
 | 
						|
            {
 | 
						|
                callbackData.State = oidcState.DeviceId;
 | 
						|
            }
 | 
						|
            return await HandleManualConnection(provider, oidcService, callbackData, oidcState.AccountId.Value);
 | 
						|
        }
 | 
						|
        else if (oidcState.FlowType == OidcFlowType.Login)
 | 
						|
        {
 | 
						|
            // Login/Registration flow
 | 
						|
            if (!string.IsNullOrEmpty(oidcState.DeviceId))
 | 
						|
            {
 | 
						|
                callbackData.State = oidcState.DeviceId;
 | 
						|
            }
 | 
						|
 | 
						|
            // Store return URL if provided
 | 
						|
            if (string.IsNullOrEmpty(oidcState.ReturnUrl) || oidcState.ReturnUrl == "/")
 | 
						|
                return await HandleLoginOrRegistration(provider, oidcService, callbackData);
 | 
						|
            var returnUrlKey = $"{ReturnUrlCachePrefix}{callbackData.State}";
 | 
						|
            await cache.SetAsync(returnUrlKey, oidcState.ReturnUrl, StateExpiration);
 | 
						|
 | 
						|
            return await HandleLoginOrRegistration(provider, oidcService, callbackData);
 | 
						|
        }
 | 
						|
 | 
						|
        return BadRequest("Unsupported flow type");
 | 
						|
    }
 | 
						|
 | 
						|
    private async Task<IActionResult> HandleManualConnection(
 | 
						|
        string provider,
 | 
						|
        OidcService oidcService,
 | 
						|
        OidcCallbackData callbackData,
 | 
						|
        Guid accountId
 | 
						|
    )
 | 
						|
    {
 | 
						|
        provider = provider.ToLower();
 | 
						|
 | 
						|
        OidcUserInfo userInfo;
 | 
						|
        try
 | 
						|
        {
 | 
						|
            userInfo = await oidcService.ProcessCallbackAsync(callbackData);
 | 
						|
        }
 | 
						|
        catch (Exception ex)
 | 
						|
        {
 | 
						|
            return BadRequest($"Error processing {provider} authentication: {ex.Message}");
 | 
						|
        }
 | 
						|
 | 
						|
        if (string.IsNullOrEmpty(userInfo.UserId))
 | 
						|
        {
 | 
						|
            return BadRequest($"{provider} did not return a valid user identifier.");
 | 
						|
        }
 | 
						|
 | 
						|
        // Extract device ID from the callback state if available
 | 
						|
        var deviceId = !string.IsNullOrEmpty(callbackData.State) ? callbackData.State : string.Empty;
 | 
						|
 | 
						|
        // Check if this provider account is already connected to any user
 | 
						|
        var existingConnection = await db.AccountConnections
 | 
						|
            .FirstOrDefaultAsync(c =>
 | 
						|
                c.Provider == provider &&
 | 
						|
                c.ProvidedIdentifier == userInfo.UserId);
 | 
						|
 | 
						|
        // If it's connected to a different user, return error
 | 
						|
        if (existingConnection != null && existingConnection.AccountId != accountId)
 | 
						|
        {
 | 
						|
            return BadRequest($"This {provider} account is already linked to another user.");
 | 
						|
        }
 | 
						|
 | 
						|
        // Check if the current user already has this provider connected
 | 
						|
        var userHasProvider = await db.AccountConnections
 | 
						|
            .AnyAsync(c =>
 | 
						|
                c.AccountId == accountId &&
 | 
						|
                c.Provider == provider);
 | 
						|
 | 
						|
        if (userHasProvider)
 | 
						|
        {
 | 
						|
            // Update existing connection with new tokens
 | 
						|
            var connection = await db.AccountConnections
 | 
						|
                .FirstOrDefaultAsync(c =>
 | 
						|
                    c.AccountId == accountId &&
 | 
						|
                    c.Provider == provider);
 | 
						|
 | 
						|
            if (connection != null)
 | 
						|
            {
 | 
						|
                connection.AccessToken = userInfo.AccessToken;
 | 
						|
                connection.RefreshToken = userInfo.RefreshToken;
 | 
						|
                connection.LastUsedAt = SystemClock.Instance.GetCurrentInstant();
 | 
						|
                connection.Meta = userInfo.ToMetadata();
 | 
						|
            }
 | 
						|
        }
 | 
						|
        else
 | 
						|
        {
 | 
						|
            // Create new connection
 | 
						|
            db.AccountConnections.Add(new SnAccountConnection
 | 
						|
            {
 | 
						|
                AccountId = accountId,
 | 
						|
                Provider = provider,
 | 
						|
                ProvidedIdentifier = userInfo.UserId!,
 | 
						|
                AccessToken = userInfo.AccessToken,
 | 
						|
                RefreshToken = userInfo.RefreshToken,
 | 
						|
                LastUsedAt = SystemClock.Instance.GetCurrentInstant(),
 | 
						|
                Meta = userInfo.ToMetadata(),
 | 
						|
            });
 | 
						|
        }
 | 
						|
 | 
						|
        try
 | 
						|
        {
 | 
						|
            await db.SaveChangesAsync();
 | 
						|
        }
 | 
						|
        catch (DbUpdateException)
 | 
						|
        {
 | 
						|
            return StatusCode(500, $"Failed to save {provider} connection. Please try again.");
 | 
						|
        }
 | 
						|
 | 
						|
        // Clean up and redirect
 | 
						|
        var returnUrlKey = $"{ReturnUrlCachePrefix}{callbackData.State}";
 | 
						|
        var returnUrl = await cache.GetAsync<string>(returnUrlKey);
 | 
						|
        await cache.RemoveAsync(returnUrlKey);
 | 
						|
 | 
						|
        var siteUrl = configuration["SiteUrl"];
 | 
						|
 | 
						|
        return Redirect(string.IsNullOrEmpty(returnUrl) ? siteUrl + "/auth/callback" : returnUrl);
 | 
						|
    }
 | 
						|
 | 
						|
    private async Task<IActionResult> HandleLoginOrRegistration(
 | 
						|
        string provider,
 | 
						|
        OidcService oidcService,
 | 
						|
        OidcCallbackData callbackData
 | 
						|
    )
 | 
						|
    {
 | 
						|
        OidcUserInfo userInfo;
 | 
						|
        try
 | 
						|
        {
 | 
						|
            userInfo = await oidcService.ProcessCallbackAsync(callbackData);
 | 
						|
        }
 | 
						|
        catch (Exception ex)
 | 
						|
        {
 | 
						|
            return BadRequest($"Error processing callback: {ex.Message}");
 | 
						|
        }
 | 
						|
 | 
						|
        if (string.IsNullOrEmpty(userInfo.Email) || string.IsNullOrEmpty(userInfo.UserId))
 | 
						|
        {
 | 
						|
            return BadRequest($"Email or user ID is missing from {provider}'s response");
 | 
						|
        }
 | 
						|
 | 
						|
        var connection = await db.AccountConnections
 | 
						|
            .Include(c => c.Account)
 | 
						|
            .FirstOrDefaultAsync(c => c.Provider == provider && c.ProvidedIdentifier == userInfo.UserId);
 | 
						|
 | 
						|
        var clock = SystemClock.Instance;
 | 
						|
        if (connection != null)
 | 
						|
        {
 | 
						|
            // Login existing user
 | 
						|
            var deviceId = !string.IsNullOrEmpty(callbackData.State) ?
 | 
						|
                callbackData.State.Split('|').FirstOrDefault() :
 | 
						|
                string.Empty;
 | 
						|
 | 
						|
            var challenge = await oidcService.CreateChallengeForUserAsync(
 | 
						|
                userInfo,
 | 
						|
                connection.Account,
 | 
						|
                HttpContext,
 | 
						|
                deviceId ?? string.Empty);
 | 
						|
            return Redirect($"/auth/callback?challenge={challenge.Id}");
 | 
						|
        }
 | 
						|
 | 
						|
        // Register new user
 | 
						|
        var account = await accounts.LookupAccount(userInfo.Email) ?? await accounts.CreateAccount(userInfo);
 | 
						|
 | 
						|
        // Create connection for new or existing user
 | 
						|
        var newConnection = new SnAccountConnection
 | 
						|
        {
 | 
						|
            Account = account,
 | 
						|
            Provider = provider,
 | 
						|
            ProvidedIdentifier = userInfo.UserId!,
 | 
						|
            AccessToken = userInfo.AccessToken,
 | 
						|
            RefreshToken = userInfo.RefreshToken,
 | 
						|
            LastUsedAt = clock.GetCurrentInstant(),
 | 
						|
            Meta = userInfo.ToMetadata()
 | 
						|
        };
 | 
						|
        db.AccountConnections.Add(newConnection);
 | 
						|
 | 
						|
        await db.SaveChangesAsync();
 | 
						|
 | 
						|
        var loginSession = await auth.CreateSessionForOidcAsync(account, clock.GetCurrentInstant());
 | 
						|
        var loginToken = auth.CreateToken(loginSession);
 | 
						|
 | 
						|
        var siteUrl = configuration["SiteUrl"];
 | 
						|
 | 
						|
        return Redirect(siteUrl + $"/auth/callback?token={loginToken}");
 | 
						|
    }
 | 
						|
 | 
						|
    private static async Task<OidcCallbackData> ExtractCallbackData(HttpRequest request)
 | 
						|
    {
 | 
						|
        var data = new OidcCallbackData();
 | 
						|
        switch (request.Method)
 | 
						|
        {
 | 
						|
            case "GET":
 | 
						|
                data.Code = Uri.UnescapeDataString(request.Query["code"].FirstOrDefault() ?? "");
 | 
						|
                data.IdToken = Uri.UnescapeDataString(request.Query["id_token"].FirstOrDefault() ?? "");
 | 
						|
                data.State = Uri.UnescapeDataString(request.Query["state"].FirstOrDefault() ?? "");
 | 
						|
                break;
 | 
						|
            case "POST" when request.HasFormContentType:
 | 
						|
                {
 | 
						|
                    var form = await request.ReadFormAsync();
 | 
						|
                    data.Code = Uri.UnescapeDataString(form["code"].FirstOrDefault() ?? "");
 | 
						|
                    data.IdToken = Uri.UnescapeDataString(form["id_token"].FirstOrDefault() ?? "");
 | 
						|
                    data.State = Uri.UnescapeDataString(form["state"].FirstOrDefault() ?? "");
 | 
						|
                    if (form.ContainsKey("user"))
 | 
						|
                        data.RawData = Uri.UnescapeDataString(form["user"].FirstOrDefault() ?? "");
 | 
						|
 | 
						|
                    break;
 | 
						|
                }
 | 
						|
        }
 | 
						|
 | 
						|
        return data;
 | 
						|
    }
 | 
						|
}
 |