// Swarm/DysonNetwork.Sphere/Auth/OpenId/ConnectionController.cs
// (repository viewer metadata: 261 lines, 9.5 KiB, C#)

using DysonNetwork.Sphere.Account;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
using Microsoft.EntityFrameworkCore;
using NodaTime;
namespace DysonNetwork.Sphere.Auth.OpenId;
[ApiController]
[Route("/api/accounts/me/connections")]
[Authorize]
public class ConnectionController(
    AppDatabase db,
    IEnumerable<OidcService> oidcServices,
    AccountService accounts,
    AuthService auth
) : ControllerBase
{
    /// <summary>
    /// Lists the external provider connections of the current user.
    /// </summary>
    /// <remarks>
    /// The response is a projection that omits the stored AccessToken / RefreshToken fields.
    /// </remarks>
    [HttpGet]
    public async Task<ActionResult<List<AccountConnection>>> GetConnections()
    {
        if (HttpContext.Items["CurrentUser"] is not Account.Account currentUser)
            return Unauthorized();

        var connections = await db.AccountConnections
            .Where(c => c.AccountId == currentUser.Id)
            .Select(c => new { c.Id, c.AccountId, c.Provider, c.ProvidedIdentifier, c.Meta, c.LastUsedAt })
            .ToListAsync();
        return Ok(connections);
    }

    /// <summary>
    /// Deletes one of the current user's connections.
    /// </summary>
    /// <param name="id">Id of the connection to delete.</param>
    /// <returns>200 on success; 404 when no connection with this id belongs to the current user.</returns>
    [HttpDelete("{id:guid}")]
    public async Task<ActionResult> RemoveConnection(Guid id)
    {
        if (HttpContext.Items["CurrentUser"] is not Account.Account currentUser)
            return Unauthorized();

        var connection = await db.AccountConnections
            .Where(c => c.Id == id && c.AccountId == currentUser.Id)
            .FirstOrDefaultAsync();
        if (connection == null)
            return NotFound();

        db.AccountConnections.Remove(connection);
        await db.SaveChangesAsync();
        return Ok();
    }

    /// <summary>
    /// Request body for <see cref="InitiateConnection"/>.
    /// </summary>
    public class ConnectProviderRequest
    {
        /// <summary>Provider name, matched case-insensitively against the registered OIDC services.</summary>
        public string Provider { get; set; } = null!;

        /// <summary>Where to redirect after the callback; defaults to "/settings/connections".</summary>
        public string? ReturnUrl { get; set; }
    }

    /// <summary>
    /// Initiates manual connection to an OAuth provider for the current user.
    /// </summary>
    /// <remarks>
    /// Stores "{accountId}|{provider}|{nonce}" and the return URL in the session under keys derived from a
    /// freshly generated state value, then returns the provider's authorization URL for the client to visit.
    /// </remarks>
    /// <returns>400 if the provider is unknown or already connected; otherwise an object with <c>authUrl</c>.</returns>
    [HttpPost("connect")]
    public async Task<ActionResult<object>> InitiateConnection([FromBody] ConnectProviderRequest request)
    {
        if (HttpContext.Items["CurrentUser"] is not Account.Account currentUser)
            return Unauthorized();

        var oidcService = FindOidcService(request.Provider);
        if (oidcService == null)
            return BadRequest($"Provider '{request.Provider}' is not supported");

        // Stored Provider values come from the callback route and may differ in case from ProviderName,
        // so compare case-insensitively; ToLower() (unlike a StringComparison overload) is translatable by EF Core.
        var providerKey = oidcService.ProviderName.ToLowerInvariant();
        var existingConnection = await db.AccountConnections
            .AnyAsync(c => c.AccountId == currentUser.Id && c.Provider.ToLower() == providerKey);
        if (existingConnection)
            return BadRequest($"You already have a {request.Provider} connection");

        var state = Guid.NewGuid().ToString("N");
        var nonce = Guid.NewGuid().ToString("N");
        HttpContext.Session.SetString($"oidc_state_{state}", $"{currentUser.Id}|{request.Provider}|{nonce}");

        // NOTE(review): ReturnUrl is client-supplied and later passed to Redirect() unvalidated (open redirect);
        // consider restricting it with Url.IsLocalUrl or an allow-list of front-end origins.
        var finalReturnUrl = !string.IsNullOrEmpty(request.ReturnUrl) ? request.ReturnUrl : "/settings/connections";
        HttpContext.Session.SetString($"oidc_return_url_{state}", finalReturnUrl);

        var authUrl = oidcService.GetAuthorizationUrl(state, nonce);
        return Ok(new
        {
            authUrl,
            message = $"Redirect to this URL to connect your {request.Provider} account"
        });
    }

    /// <summary>
    /// OAuth/OIDC redirect target for every provider (GET query or POST form).
    /// </summary>
    /// <remarks>
    /// If the session holds state written by <see cref="InitiateConnection"/>, the callback links the provider
    /// account to that user; otherwise it is treated as a login or registration.
    /// </remarks>
    /// <param name="provider">Provider name from the route, matched case-insensitively.</param>
    [AllowAnonymous]
    [Route("/auth/callback/{provider}")]
    [HttpGet, HttpPost]
    public async Task<IActionResult> HandleCallback([FromRoute] string provider)
    {
        var oidcService = FindOidcService(provider);
        if (oidcService == null)
            return BadRequest($"Provider '{provider}' is not supported.");

        var callbackData = await ExtractCallbackData(Request);
        if (callbackData.State == null)
            return BadRequest("State parameter is missing.");

        // The session state entry is single-use: it is removed before it is acted on.
        var stateKey = $"oidc_state_{callbackData.State}";
        var sessionState = HttpContext.Session.GetString(stateKey);
        HttpContext.Session.Remove(stateKey);

        // No session state: this is a login or registration flow.
        if (sessionState == null)
            return await HandleLoginOrRegistration(provider, oidcService, callbackData);

        // Session state present: manual connection flow for an existing user.
        // Format written by InitiateConnection: "{accountId}|{provider}|{nonce}".
        // NOTE(review): the nonce (stateParts[2]) is not checked here; confirm ProcessCallbackAsync validates it.
        var stateParts = sessionState.Split('|');
        if (stateParts.Length != 3 ||
            !stateParts[1].Equals(provider, StringComparison.OrdinalIgnoreCase) ||
            !Guid.TryParse(stateParts[0], out var accountId))
            return BadRequest("State mismatch.");

        return await HandleManualConnection(provider, oidcService, callbackData, accountId);
    }

    /// <summary>
    /// Returns the registered OIDC service whose ProviderName matches <paramref name="provider"/>
    /// case-insensitively, or null when there is none (including when <paramref name="provider"/> is null).
    /// </summary>
    private OidcService? FindOidcService(string? provider) =>
        oidcServices.FirstOrDefault(s => s.ProviderName.Equals(provider, StringComparison.OrdinalIgnoreCase));

    /// <summary>
    /// Links (or refreshes) the provider account for <paramref name="accountId"/> and redirects to the
    /// return URL stored by <see cref="InitiateConnection"/>, or "/" when none is stored.
    /// </summary>
    /// <returns>400 if the callback fails, the provider returns no user id, or the provider account is
    /// already linked to a different user.</returns>
    private async Task<IActionResult> HandleManualConnection(string provider, OidcService oidcService,
        OidcCallbackData callbackData, Guid accountId)
    {
        OidcUserInfo userInfo;
        try
        {
            userInfo = await oidcService.ProcessCallbackAsync(callbackData);
        }
        catch (Exception ex)
        {
            return BadRequest($"Error processing callback: {ex.Message}");
        }

        if (string.IsNullOrEmpty(userInfo.UserId))
            return BadRequest($"User ID is missing from {provider}'s response");

        // string.Equals(..., StringComparison) cannot be translated to SQL by EF Core (it throws at runtime);
        // lower-casing both sides performs the same case-insensitive match server-side.
        var providerKey = provider.ToLowerInvariant();

        var existingConnection = await db.AccountConnections
            .FirstOrDefaultAsync(c =>
                c.Provider.ToLower() == providerKey &&
                c.ProvidedIdentifier == userInfo.UserId);
        if (existingConnection != null && existingConnection.AccountId != accountId)
            return BadRequest($"This {provider} account is already linked to another user.");

        var userConnection = await db.AccountConnections
            .FirstOrDefaultAsync(c => c.AccountId == accountId && c.Provider.ToLower() == providerKey);

        var now = SystemClock.Instance.GetCurrentInstant();
        if (userConnection != null)
        {
            userConnection.AccessToken = userInfo.AccessToken;
            userConnection.RefreshToken = userInfo.RefreshToken;
            userConnection.LastUsedAt = now;
        }
        else
        {
            db.AccountConnections.Add(new AccountConnection
            {
                AccountId = accountId,
                Provider = provider,
                ProvidedIdentifier = userInfo.UserId!,
                AccessToken = userInfo.AccessToken,
                RefreshToken = userInfo.RefreshToken,
                LastUsedAt = now,
                Meta = userInfo.ToMetadata(),
            });
        }

        await db.SaveChangesAsync();

        var returnUrlKey = $"oidc_return_url_{callbackData.State}";
        var returnUrl = HttpContext.Session.GetString(returnUrlKey);
        HttpContext.Session.Remove(returnUrlKey);
        return Redirect(string.IsNullOrEmpty(returnUrl) ? "/" : returnUrl);
    }

    /// <summary>
    /// Logs in the account already linked to this provider identity. Otherwise it links the identity to the
    /// account found by e-mail (or a newly created one) and logs that account in.
    /// </summary>
    /// <returns>400 if the callback fails or the provider returns no e-mail or user id; otherwise a redirect
    /// to the token hand-off page.</returns>
    private async Task<IActionResult> HandleLoginOrRegistration(
        string provider,
        OidcService oidcService,
        OidcCallbackData callbackData
    )
    {
        OidcUserInfo userInfo;
        try
        {
            userInfo = await oidcService.ProcessCallbackAsync(callbackData);
        }
        catch (Exception ex)
        {
            return BadRequest($"Error processing callback: {ex.Message}");
        }

        if (string.IsNullOrEmpty(userInfo.Email) || string.IsNullOrEmpty(userInfo.UserId))
            return BadRequest($"Email or user ID is missing from {provider}'s response");

        // Case-insensitive to match how HandleManualConnection stores/looks up the route's provider value.
        var providerKey = provider.ToLowerInvariant();
        var connection = await db.AccountConnections
            .Include(c => c.Account)
            .FirstOrDefaultAsync(c => c.Provider.ToLower() == providerKey && c.ProvidedIdentifier == userInfo.UserId);

        // Already linked: log the existing user in.
        if (connection != null)
            return await RedirectWithNewSessionAsync(connection.Account);

        // NOTE(review): an existing account is matched solely by the e-mail returned by the provider; confirm
        // providers only return verified e-mails, otherwise this may attach the connection to another person's account.
        var account = await accounts.LookupAccount(userInfo.Email) ?? await accounts.CreateAccount(userInfo);

        db.AccountConnections.Add(new AccountConnection
        {
            Account = account,
            Provider = provider,
            ProvidedIdentifier = userInfo.UserId!,
            AccessToken = userInfo.AccessToken,
            RefreshToken = userInfo.RefreshToken,
            LastUsedAt = SystemClock.Instance.GetCurrentInstant(),
            Meta = userInfo.ToMetadata()
        });
        await db.SaveChangesAsync();

        return await RedirectWithNewSessionAsync(account);
    }

    /// <summary>
    /// Creates a session for <paramref name="account"/> and redirects to "/auth/token" with its token
    /// in the query string.
    /// </summary>
    private async Task<IActionResult> RedirectWithNewSessionAsync(Account.Account account)
    {
        var session = await auth.CreateSessionAsync(account, SystemClock.Instance.GetCurrentInstant());
        var token = auth.CreateToken(session);
        // The token is placed in a query string, so escape characters such as '+', '/' and '='.
        return Redirect($"/auth/token?token={Uri.EscapeDataString(token)}");
    }

    /// <summary>
    /// Reads code, id_token and state from the query string (GET) or the form body (form POST).
    /// For form POSTs the optional "user" field is copied into RawData.
    /// </summary>
    /// <remarks>Any other method or content type yields an empty <see cref="OidcCallbackData"/>.</remarks>
    private static async Task<OidcCallbackData> ExtractCallbackData(HttpRequest request)
    {
        var data = new OidcCallbackData();
        switch (request.Method)
        {
            case "GET":
                data.Code = request.Query["code"].FirstOrDefault() ?? "";
                data.IdToken = request.Query["id_token"].FirstOrDefault() ?? "";
                data.State = request.Query["state"].FirstOrDefault();
                break;
            case "POST" when request.HasFormContentType:
            {
                var form = await request.ReadFormAsync();
                data.Code = form["code"].FirstOrDefault() ?? "";
                data.IdToken = form["id_token"].FirstOrDefault() ?? "";
                data.State = form["state"].FirstOrDefault();
                if (form.ContainsKey("user"))
                {
                    data.RawData = form["user"].FirstOrDefault();
                }

                break;
            }
        }

        return data;
    }
}