370 lines
13 KiB
C#
370 lines
13 KiB
C#
using DysonNetwork.Sphere.Account;
|
|
using Microsoft.AspNetCore.Authorization;
|
|
using Microsoft.AspNetCore.Mvc;
|
|
using Microsoft.EntityFrameworkCore;
|
|
using DysonNetwork.Sphere.Storage;
|
|
using NodaTime;
|
|
|
|
namespace DysonNetwork.Sphere.Auth.OpenId;
|
|
|
|
[ApiController]
[Route("/accounts/me/connections")]
[Authorize]
public class ConnectionController(
    AppDatabase db,
    IEnumerable<OidcService> oidcServices,
    AccountService accounts,
    AuthService auth,
    ICacheService cache
) : ControllerBase
{
    private const string StateCachePrefix = "oidc-state:";
    private const string ReturnUrlCachePrefix = "oidc-returning:";
    private const string DefaultReturnUrl = "/settings/connections";
    private static readonly TimeSpan StateExpiration = TimeSpan.FromMinutes(15);

    /// <summary>
    /// Lists the external provider connections belonging to the current user.
    /// </summary>
    /// <remarks>
    /// The projection omits the AccessToken/RefreshToken columns, so provider tokens are
    /// not returned to the client.
    /// NOTE(review): the declared result type is List&lt;AccountConnection&gt;, but an anonymous
    /// projection is what is actually returned (affects generated API docs only).
    /// </remarks>
    [HttpGet]
    public async Task<ActionResult<List<AccountConnection>>> GetConnections()
    {
        if (HttpContext.Items["CurrentUser"] is not Account.Account currentUser)
            return Unauthorized();

        var connections = await db.AccountConnections
            .Where(c => c.AccountId == currentUser.Id)
            .Select(c => new
            {
                c.Id,
                c.AccountId,
                c.Provider,
                c.ProvidedIdentifier,
                c.Meta,
                c.LastUsedAt,
                c.CreatedAt,
                c.UpdatedAt,
            })
            .ToListAsync();
        return Ok(connections);
    }

    /// <summary>
    /// Deletes one of the current user's connections.
    /// </summary>
    /// <param name="id">Id of the connection to remove.</param>
    /// <returns>404 when no connection with this id belongs to the current user; 200 after deletion.</returns>
    [HttpDelete("{id:guid}")]
    public async Task<ActionResult> RemoveConnection(Guid id)
    {
        if (HttpContext.Items["CurrentUser"] is not Account.Account currentUser)
            return Unauthorized();

        // Filtering by AccountId means another user's connection id yields 404 rather than a delete.
        var connection = await db.AccountConnections
            .FirstOrDefaultAsync(c => c.Id == id && c.AccountId == currentUser.Id);
        if (connection == null)
            return NotFound();

        db.AccountConnections.Remove(connection);
        await db.SaveChangesAsync();

        return Ok();
    }

    /// <summary>
    /// Links an Apple account to the current user using the identity token and
    /// authorization code obtained by a native (mobile) Sign in with Apple flow.
    /// </summary>
    /// <param name="request">The Apple identity token and authorization code.</param>
    /// <returns>
    /// 503 when no Apple OIDC service is registered; 400 when the token cannot be processed,
    /// carries no user id, is already linked, or the user already has an Apple connection;
    /// 200 on success.
    /// </returns>
    [HttpPost("/auth/connect/apple/mobile")]
    public async Task<ActionResult> ConnectAppleMobile([FromBody] AppleMobileConnectRequest request)
    {
        const string provider = "apple";

        if (HttpContext.Items["CurrentUser"] is not Account.Account currentUser)
            return Unauthorized();

        if (GetOidcService(provider) is not AppleOidcService appleService)
            return StatusCode(503, "Apple OIDC service not available");

        var callbackData = new OidcCallbackData
        {
            IdToken = request.IdentityToken,
            Code = request.AuthorizationCode,
        };

        OidcUserInfo userInfo;
        try
        {
            userInfo = await appleService.ProcessCallbackAsync(callbackData);
        }
        catch (Exception ex)
        {
            return BadRequest($"Error processing Apple token: {ex.Message}");
        }

        // Without an identifier the lookup below would match on NULL and the insert would store no identity.
        if (string.IsNullOrEmpty(userInfo.UserId))
            return BadRequest($"{provider} did not return a valid user identifier.");

        var existingConnection = await db.AccountConnections
            .FirstOrDefaultAsync(c =>
                c.Provider == provider &&
                c.ProvidedIdentifier == userInfo.UserId);

        if (existingConnection != null)
        {
            return BadRequest(
                $"This Apple account is already linked to {(existingConnection.AccountId == currentUser.Id ? "your account" : "another user")}.");
        }

        // Same one-connection-per-provider rule that InitiateConnection enforces.
        var userHasProvider = await db.AccountConnections
            .AnyAsync(c => c.AccountId == currentUser.Id && c.Provider == provider);
        if (userHasProvider)
            return BadRequest("You already have an Apple connection");

        db.AccountConnections.Add(new AccountConnection
        {
            AccountId = currentUser.Id,
            Provider = provider,
            ProvidedIdentifier = userInfo.UserId,
            AccessToken = userInfo.AccessToken,
            RefreshToken = userInfo.RefreshToken,
            LastUsedAt = SystemClock.Instance.GetCurrentInstant(),
            Meta = userInfo.ToMetadata(),
        });

        await db.SaveChangesAsync();

        return Ok(new { message = "Successfully connected Apple account." });
    }

    /// <summary>
    /// Finds the registered OIDC service whose ProviderName matches <paramref name="provider"/>
    /// (ordinal, case-insensitive), or null when none matches.
    /// </summary>
    private OidcService? GetOidcService(string provider)
    {
        return oidcServices.FirstOrDefault(s => s.ProviderName.Equals(provider, StringComparison.OrdinalIgnoreCase));
    }

    /// <summary>
    /// Body of <see cref="InitiateConnection"/>.
    /// </summary>
    public class ConnectProviderRequest
    {
        /// <summary>Provider name, matched case-insensitively against the registered OIDC services.</summary>
        public string Provider { get; set; } = null!;

        /// <summary>Where to redirect after the callback; defaults to "/settings/connections" when empty.</summary>
        public string? ReturnUrl { get; set; }
    }

    /// <summary>
    /// Initiates manual connection to an OAuth provider for the current user
    /// </summary>
    /// <remarks>
    /// Generates a random state and nonce. The state is cached for 15 minutes, mapped to
    /// "{accountId}|{provider}|{nonce}", alongside the return URL. The response contains the
    /// provider authorization URL the client should navigate to.
    /// NOTE(review): ReturnUrl is client-supplied and later passed to Redirect() unvalidated;
    /// consider restricting it to local or allow-listed URLs.
    /// NOTE(review): the nonce is stored in the state but never compared on callback in this
    /// controller; confirm whether ProcessCallbackAsync verifies it.
    /// </remarks>
    [HttpPost("connect")]
    public async Task<ActionResult<object>> InitiateConnection([FromBody] ConnectProviderRequest request)
    {
        if (HttpContext.Items["CurrentUser"] is not Account.Account currentUser)
            return Unauthorized();

        var oidcService = GetOidcService(request.Provider);
        if (oidcService == null)
            return BadRequest($"Provider '{request.Provider}' is not supported");

        // Connections are persisted with a lower-cased provider key (HandleManualConnection,
        // ConnectAppleMobile), so compare against the same normalized form.
        var providerKey = oidcService.ProviderName.ToLowerInvariant();
        var existingConnection = await db.AccountConnections
            .AnyAsync(c => c.AccountId == currentUser.Id && c.Provider == providerKey);

        if (existingConnection)
            return BadRequest($"You already have a {request.Provider} connection");

        var state = Guid.NewGuid().ToString("N");
        var nonce = Guid.NewGuid().ToString("N");
        var stateValue = $"{currentUser.Id}|{request.Provider}|{nonce}";
        var finalReturnUrl = !string.IsNullOrEmpty(request.ReturnUrl) ? request.ReturnUrl : DefaultReturnUrl;

        // Store state and return URL in cache
        await cache.SetAsync($"{StateCachePrefix}{state}", stateValue, StateExpiration);
        await cache.SetAsync($"{ReturnUrlCachePrefix}{state}", finalReturnUrl, StateExpiration);

        var authUrl = oidcService.GetAuthorizationUrl(state, nonce);

        return Ok(new
        {
            authUrl,
            message = $"Redirect to this URL to connect your {request.Provider} account"
        });
    }

    /// <summary>
    /// OAuth/OIDC redirect target for every provider (GET query string or form POST).
    /// Validates and consumes the cached state, then links the provider account to the
    /// account recorded in that state.
    /// </summary>
    /// <param name="provider">Provider name from the route, matched case-insensitively.</param>
    [AllowAnonymous]
    [Route("/auth/callback/{provider}")]
    [HttpGet, HttpPost]
    public async Task<IActionResult> HandleCallback([FromRoute] string provider)
    {
        var oidcService = GetOidcService(provider);
        if (oidcService == null)
            return BadRequest($"Provider '{provider}' is not supported.");

        var callbackData = await ExtractCallbackData(Request);
        // GET callbacks produce "" (not null) for a missing state, so test for both.
        if (string.IsNullOrEmpty(callbackData.State))
            return BadRequest("State parameter is missing.");

        // Get and validate state from cache
        var stateKey = $"{StateCachePrefix}{callbackData.State}";
        var stateValue = await cache.GetAsync<string>(stateKey);
        if (string.IsNullOrEmpty(stateValue))
            return BadRequest("Invalid or expired state parameter");

        // Remove state from cache to prevent replay attacks
        await cache.RemoveAsync(stateKey);

        // Layout written by InitiateConnection: "{accountId}|{provider}|{nonce}".
        var stateParts = stateValue.Split('|');
        if (stateParts.Length != 3 || !Guid.TryParse(stateParts[0], out var accountId))
            return BadRequest("Invalid state format");

        // A state is issued for one provider; refuse to redeem it on another provider's callback.
        if (!string.Equals(stateParts[1], provider, StringComparison.OrdinalIgnoreCase))
            return BadRequest("State does not match the callback provider.");

        return await HandleManualConnection(provider, oidcService, callbackData, accountId);
    }

    /// <summary>
    /// Exchanges the callback data with the provider and creates or refreshes the
    /// connection for <paramref name="accountId"/>, then redirects to the cached return URL.
    /// </summary>
    /// <returns>
    /// 400 when processing fails, no user id is returned, or the provider identity is linked
    /// to a different account; 500 when saving fails; otherwise a redirect.
    /// </returns>
    private async Task<IActionResult> HandleManualConnection(
        string provider,
        OidcService oidcService,
        OidcCallbackData callbackData,
        Guid accountId
    )
    {
        // Invariant culture keeps the stored key locale-independent.
        provider = provider.ToLowerInvariant();

        OidcUserInfo userInfo;
        try
        {
            userInfo = await oidcService.ProcessCallbackAsync(callbackData);
        }
        catch (Exception ex)
        {
            return BadRequest($"Error processing {provider} authentication: {ex.Message}");
        }

        if (string.IsNullOrEmpty(userInfo.UserId))
            return BadRequest($"{provider} did not return a valid user identifier.");

        // Check if this provider account is already connected to any user
        var existingConnection = await db.AccountConnections
            .FirstOrDefaultAsync(c =>
                c.Provider == provider &&
                c.ProvidedIdentifier == userInfo.UserId);

        // If it's connected to a different user, return error
        if (existingConnection != null && existingConnection.AccountId != accountId)
            return BadRequest($"This {provider} account is already linked to another user.");

        // One query both answers "does the user already have this provider?" and fetches the row to update.
        var connection = await db.AccountConnections
            .FirstOrDefaultAsync(c =>
                c.AccountId == accountId &&
                c.Provider == provider);

        var now = SystemClock.Instance.GetCurrentInstant();
        if (connection != null)
        {
            // Update existing connection with new tokens.
            // NOTE(review): ProvidedIdentifier is left unchanged even if the user authenticated
            // with a different provider identity than the one stored; confirm this is intended.
            connection.AccessToken = userInfo.AccessToken;
            connection.RefreshToken = userInfo.RefreshToken;
            connection.LastUsedAt = now;
            connection.Meta = userInfo.ToMetadata();
        }
        else
        {
            db.AccountConnections.Add(new AccountConnection
            {
                AccountId = accountId,
                Provider = provider,
                ProvidedIdentifier = userInfo.UserId,
                AccessToken = userInfo.AccessToken,
                RefreshToken = userInfo.RefreshToken,
                LastUsedAt = now,
                Meta = userInfo.ToMetadata(),
            });
        }

        try
        {
            await db.SaveChangesAsync();
        }
        catch (DbUpdateException)
        {
            return StatusCode(500, $"Failed to save {provider} connection. Please try again.");
        }

        // Clean up and redirect
        var returnUrlKey = $"{ReturnUrlCachePrefix}{callbackData.State}";
        var returnUrl = await cache.GetAsync<string>(returnUrlKey);
        await cache.RemoveAsync(returnUrlKey);

        return Redirect(string.IsNullOrEmpty(returnUrl) ? DefaultReturnUrl : returnUrl);
    }

    /// <summary>
    /// Signs in the account linked to the provider identity. If none is linked, finds an
    /// account by email (or creates one), links the identity, and signs it in.
    /// </summary>
    /// <remarks>
    /// NOTE(review): this method is not referenced anywhere in this controller.
    /// NOTE(review): linking an existing account purely by matching email is only safe if the
    /// provider guarantees the email is verified; confirm before wiring this up.
    /// </remarks>
    private async Task<IActionResult> HandleLoginOrRegistration(
        string provider,
        OidcService oidcService,
        OidcCallbackData callbackData
    )
    {
        // Same normalized key the connection flows persist.
        provider = provider.ToLowerInvariant();

        OidcUserInfo userInfo;
        try
        {
            userInfo = await oidcService.ProcessCallbackAsync(callbackData);
        }
        catch (Exception ex)
        {
            return BadRequest($"Error processing callback: {ex.Message}");
        }

        if (string.IsNullOrEmpty(userInfo.Email) || string.IsNullOrEmpty(userInfo.UserId))
            return BadRequest($"Email or user ID is missing from {provider}'s response");

        var connection = await db.AccountConnections
            .Include(c => c.Account)
            .FirstOrDefaultAsync(c => c.Provider == provider && c.ProvidedIdentifier == userInfo.UserId);

        var clock = SystemClock.Instance;
        if (connection != null)
        {
            // Login existing user
            var session = await auth.CreateSessionAsync(connection.Account, clock.GetCurrentInstant());
            var token = auth.CreateToken(session);
            // Escape so token characters such as '+', '/' or '=' survive the query string.
            return Redirect($"/auth/token?token={Uri.EscapeDataString(token)}");
        }

        // Register new user
        var account = await accounts.LookupAccount(userInfo.Email) ?? await accounts.CreateAccount(userInfo);

        // Create connection for new or existing user
        var newConnection = new AccountConnection
        {
            Account = account,
            Provider = provider,
            ProvidedIdentifier = userInfo.UserId,
            AccessToken = userInfo.AccessToken,
            RefreshToken = userInfo.RefreshToken,
            LastUsedAt = clock.GetCurrentInstant(),
            Meta = userInfo.ToMetadata()
        };
        db.AccountConnections.Add(newConnection);

        await db.SaveChangesAsync();

        var loginSession = await auth.CreateSessionAsync(account, clock.GetCurrentInstant());
        var loginToken = auth.CreateToken(loginSession);
        return Redirect($"/auth/token?token={Uri.EscapeDataString(loginToken)}");
    }

    /// <summary>
    /// Reads code, id_token and state from a provider callback. For form POSTs it also reads
    /// the "user" form field into RawData.
    /// </summary>
    /// <remarks>
    /// Values from Request.Query and ReadFormAsync are already URL-decoded by ASP.NET Core,
    /// so they are used as-is; decoding them again would corrupt any value containing a
    /// literal '%'. Missing fields become "" rather than null. A request that is neither GET
    /// nor a form-encoded POST leaves every field at its default.
    /// </remarks>
    private static async Task<OidcCallbackData> ExtractCallbackData(HttpRequest request)
    {
        var data = new OidcCallbackData();
        switch (request.Method)
        {
            case "GET":
                data.Code = request.Query["code"].FirstOrDefault() ?? "";
                data.IdToken = request.Query["id_token"].FirstOrDefault() ?? "";
                data.State = request.Query["state"].FirstOrDefault() ?? "";
                break;
            case "POST" when request.HasFormContentType:
            {
                var form = await request.ReadFormAsync();
                data.Code = form["code"].FirstOrDefault() ?? "";
                data.IdToken = form["id_token"].FirstOrDefault() ?? "";
                data.State = form["state"].FirstOrDefault() ?? "";
                if (form.ContainsKey("user"))
                    data.RawData = form["user"].FirstOrDefault() ?? "";

                break;
            }
        }

        return data;
    }
}