Enrich the user connections

This commit is contained in:
LittleSheep 2025-06-15 23:49:26 +08:00
parent 44ff09c119
commit 90eca43284
8 changed files with 3772 additions and 74 deletions

View File

@ -180,9 +180,10 @@ public class AccountConnection : ModelBase
public Guid Id { get; set; } = Guid.NewGuid(); public Guid Id { get; set; } = Guid.NewGuid();
[MaxLength(4096)] public string Provider { get; set; } = null!; [MaxLength(4096)] public string Provider { get; set; } = null!;
[MaxLength(8192)] public string ProvidedIdentifier { get; set; } = null!; [MaxLength(8192)] public string ProvidedIdentifier { get; set; } = null!;
[Column(TypeName = "jsonb")] public Dictionary<string, object>? Meta { get; set; } = new();
[MaxLength(4096)] public string? AccessToken { get; set; } [JsonIgnore] [MaxLength(4096)] public string? AccessToken { get; set; }
[MaxLength(4096)] public string? RefreshToken { get; set; } [JsonIgnore] [MaxLength(4096)] public string? RefreshToken { get; set; }
public Instant? LastUsedAt { get; set; } public Instant? LastUsedAt { get; set; }
public Guid AccountId { get; set; } public Guid AccountId { get; set; }

View File

@ -1,15 +0,0 @@
using DysonNetwork.Sphere.Account;
using Microsoft.AspNetCore.Mvc;
using Microsoft.EntityFrameworkCore;
using NodaTime;
namespace DysonNetwork.Sphere.Auth.OpenId;
/// <summary>
/// This controller is designed to handle the OAuth callback.
/// </summary>
[ApiController]
[Route("/auth/callback")]
public class AuthCallbackController : ControllerBase
{
}

View File

@ -10,11 +10,10 @@ namespace DysonNetwork.Sphere.Auth.OpenId;
[Route("/api/accounts/me/connections")] [Route("/api/accounts/me/connections")]
[Authorize] [Authorize]
public class ConnectionController( public class ConnectionController(
AppDatabase db, AppDatabase db,
IEnumerable<OidcService> oidcServices, IEnumerable<OidcService> oidcServices,
AccountService accountService, AccountService accounts,
AuthService authService, AuthService auth
IClock clock
) : ControllerBase ) : ControllerBase
{ {
[HttpGet] [HttpGet]
@ -25,7 +24,7 @@ public class ConnectionController(
var connections = await db.AccountConnections var connections = await db.AccountConnections
.Where(c => c.AccountId == currentUser.Id) .Where(c => c.AccountId == currentUser.Id)
.Select(c => new { c.Id, c.AccountId, c.Provider, c.ProvidedIdentifier }) .Select(c => new { c.Id, c.AccountId, c.Provider, c.ProvidedIdentifier, c.Meta, c.LastUsedAt })
.ToListAsync(); .ToListAsync();
return Ok(connections); return Ok(connections);
} }
@ -63,7 +62,8 @@ public class ConnectionController(
if (HttpContext.Items["CurrentUser"] is not Account.Account currentUser) if (HttpContext.Items["CurrentUser"] is not Account.Account currentUser)
return Unauthorized(); return Unauthorized();
var oidcService = oidcServices.FirstOrDefault(s => s.ProviderName.Equals(request.Provider, StringComparison.OrdinalIgnoreCase)); var oidcService = oidcServices.FirstOrDefault(s =>
s.ProviderName.Equals(request.Provider, StringComparison.OrdinalIgnoreCase));
if (oidcService == null) if (oidcService == null)
return BadRequest($"Provider '{request.Provider}' is not supported"); return BadRequest($"Provider '{request.Provider}' is not supported");
@ -94,7 +94,8 @@ public class ConnectionController(
[HttpGet, HttpPost] [HttpGet, HttpPost]
public async Task<IActionResult> HandleCallback([FromRoute] string provider) public async Task<IActionResult> HandleCallback([FromRoute] string provider)
{ {
var oidcService = oidcServices.FirstOrDefault(s => s.ProviderName.Equals(provider, StringComparison.OrdinalIgnoreCase)); var oidcService =
oidcServices.FirstOrDefault(s => s.ProviderName.Equals(provider, StringComparison.OrdinalIgnoreCase));
if (oidcService == null) if (oidcService == null)
return BadRequest($"Provider '{provider}' is not supported."); return BadRequest($"Provider '{provider}' is not supported.");
@ -106,21 +107,19 @@ public class ConnectionController(
HttpContext.Session.Remove($"oidc_state_{callbackData.State}"); HttpContext.Session.Remove($"oidc_state_{callbackData.State}");
// If sessionState is present, it's a manual connection flow for an existing user. // If sessionState is present, it's a manual connection flow for an existing user.
if (sessionState != null) if (sessionState == null) return await HandleLoginOrRegistration(provider, oidcService, callbackData);
{ var stateParts = sessionState.Split('|');
var stateParts = sessionState.Split('|'); if (stateParts.Length != 3 || !stateParts[1].Equals(provider, StringComparison.OrdinalIgnoreCase))
if (stateParts.Length != 3 || !stateParts[1].Equals(provider, StringComparison.OrdinalIgnoreCase)) return BadRequest("State mismatch.");
return BadRequest("State mismatch.");
var accountId = Guid.Parse(stateParts[0]);
return await HandleManualConnection(provider, oidcService, callbackData, accountId);
var accountId = Guid.Parse(stateParts[0]);
return await HandleManualConnection(provider, oidcService, callbackData, accountId);
}
// Otherwise, it's a login or registration flow. // Otherwise, it's a login or registration flow.
return await HandleLoginOrRegistration(provider, oidcService, callbackData);
} }
private async Task<IActionResult> HandleManualConnection(string provider, OidcService oidcService, OidcCallbackData callbackData, Guid accountId) private async Task<IActionResult> HandleManualConnection(string provider, OidcService oidcService,
OidcCallbackData callbackData, Guid accountId)
{ {
OidcUserInfo userInfo; OidcUserInfo userInfo;
try try
@ -133,7 +132,9 @@ public class ConnectionController(
} }
var existingConnection = await db.AccountConnections var existingConnection = await db.AccountConnections
.FirstOrDefaultAsync(c => c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase) && c.ProvidedIdentifier == userInfo.UserId); .FirstOrDefaultAsync(c =>
c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase) &&
c.ProvidedIdentifier == userInfo.UserId);
if (existingConnection != null && existingConnection.AccountId != accountId) if (existingConnection != null && existingConnection.AccountId != accountId)
{ {
@ -141,8 +142,10 @@ public class ConnectionController(
} }
var userConnection = await db.AccountConnections var userConnection = await db.AccountConnections
.FirstOrDefaultAsync(c => c.AccountId == accountId && c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase)); .FirstOrDefaultAsync(c =>
c.AccountId == accountId && c.Provider.Equals(provider, StringComparison.OrdinalIgnoreCase));
var clock = SystemClock.Instance;
if (userConnection != null) if (userConnection != null)
{ {
userConnection.AccessToken = userInfo.AccessToken; userConnection.AccessToken = userInfo.AccessToken;
@ -158,7 +161,8 @@ public class ConnectionController(
ProvidedIdentifier = userInfo.UserId!, ProvidedIdentifier = userInfo.UserId!,
AccessToken = userInfo.AccessToken, AccessToken = userInfo.AccessToken,
RefreshToken = userInfo.RefreshToken, RefreshToken = userInfo.RefreshToken,
LastUsedAt = clock.GetCurrentInstant() LastUsedAt = clock.GetCurrentInstant(),
Meta = userInfo.ToMetadata(),
}); });
} }
@ -170,7 +174,11 @@ public class ConnectionController(
return Redirect(string.IsNullOrEmpty(returnUrl) ? "/" : returnUrl); return Redirect(string.IsNullOrEmpty(returnUrl) ? "/" : returnUrl);
} }
private async Task<IActionResult> HandleLoginOrRegistration(string provider, OidcService oidcService, OidcCallbackData callbackData) private async Task<IActionResult> HandleLoginOrRegistration(
string provider,
OidcService oidcService,
OidcCallbackData callbackData
)
{ {
OidcUserInfo userInfo; OidcUserInfo userInfo;
try try
@ -191,25 +199,17 @@ public class ConnectionController(
.Include(c => c.Account) .Include(c => c.Account)
.FirstOrDefaultAsync(c => c.Provider == provider && c.ProvidedIdentifier == userInfo.UserId); .FirstOrDefaultAsync(c => c.Provider == provider && c.ProvidedIdentifier == userInfo.UserId);
var clock = SystemClock.Instance;
if (connection != null) if (connection != null)
{ {
// Login existing user // Login existing user
var session = await authService.CreateSessionAsync(connection.Account, clock.GetCurrentInstant()); var session = await auth.CreateSessionAsync(connection.Account, clock.GetCurrentInstant());
var token = authService.CreateToken(session); var token = auth.CreateToken(session);
return Redirect($"/?token={token}"); return Redirect($"/?token={token}");
} }
var account = await accountService.LookupAccount(userInfo.Email); // Register new user
if (account == null) var account = await accounts.LookupAccount(userInfo.Email) ?? await accounts.CreateAccount(userInfo);
{
// Register new user
account = await accountService.CreateAccount(userInfo);
}
if (account == null)
{
return BadRequest("Unable to create or link account.");
}
// Create connection for new or existing user // Create connection for new or existing user
var newConnection = new AccountConnection var newConnection = new AccountConnection
@ -219,40 +219,43 @@ public class ConnectionController(
ProvidedIdentifier = userInfo.UserId!, ProvidedIdentifier = userInfo.UserId!,
AccessToken = userInfo.AccessToken, AccessToken = userInfo.AccessToken,
RefreshToken = userInfo.RefreshToken, RefreshToken = userInfo.RefreshToken,
LastUsedAt = clock.GetCurrentInstant() LastUsedAt = clock.GetCurrentInstant(),
Meta = userInfo.ToMetadata()
}; };
db.AccountConnections.Add(newConnection); db.AccountConnections.Add(newConnection);
await db.SaveChangesAsync(); await db.SaveChangesAsync();
var loginSession = await authService.CreateSessionAsync(account, clock.GetCurrentInstant()); var loginSession = await auth.CreateSessionAsync(account, clock.GetCurrentInstant());
var loginToken = authService.CreateToken(loginSession); var loginToken = auth.CreateToken(loginSession);
return Redirect($"/?token={loginToken}"); return Redirect($"/?token={loginToken}");
} }
private async Task<OidcCallbackData> ExtractCallbackData(HttpRequest request) private static async Task<OidcCallbackData> ExtractCallbackData(HttpRequest request)
{ {
var data = new OidcCallbackData(); var data = new OidcCallbackData();
if (request.Method == "GET") switch (request.Method)
{ {
data.Code = request.Query["code"].FirstOrDefault() ?? ""; case "GET":
data.IdToken = request.Query["id_token"].FirstOrDefault() ?? ""; data.Code = request.Query["code"].FirstOrDefault() ?? "";
data.State = request.Query["state"].FirstOrDefault(); data.IdToken = request.Query["id_token"].FirstOrDefault() ?? "";
} data.State = request.Query["state"].FirstOrDefault();
else if (request.Method == "POST" && request.HasFormContentType) break;
{ case "POST" when request.HasFormContentType:
var form = await request.ReadFormAsync();
data.Code = form["code"].FirstOrDefault() ?? "";
data.IdToken = form["id_token"].FirstOrDefault() ?? "";
data.State = form["state"].FirstOrDefault();
if (form.ContainsKey("user"))
{ {
data.RawData = form["user"].FirstOrDefault(); var form = await request.ReadFormAsync();
data.Code = form["code"].FirstOrDefault() ?? "";
data.IdToken = form["id_token"].FirstOrDefault() ?? "";
data.State = form["state"].FirstOrDefault();
if (form.ContainsKey("user"))
{
data.RawData = form["user"].FirstOrDefault();
}
break;
} }
} }
return data; return data;
} }
} }

View File

@ -105,7 +105,7 @@ public class OidcController(
if (string.IsNullOrEmpty(userInfo.Email)) if (string.IsNullOrEmpty(userInfo.Email))
throw new ArgumentException("Email is required for account creation"); throw new ArgumentException("Email is required for account creation");
// Check if account exists by email // Check if an account exists by email
var existingAccount = await accounts.LookupAccount(userInfo.Email); var existingAccount = await accounts.LookupAccount(userInfo.Email);
if (existingAccount != null) if (existingAccount != null)
{ {
@ -116,12 +116,28 @@ public class OidcController(
c.ProvidedIdentifier == userInfo.UserId); c.ProvidedIdentifier == userInfo.UserId);
// If no connection exists, create one // If no connection exists, create one
if (existingConnection != null) return existingAccount; if (existingConnection != null)
{
await db.AccountConnections
.Where(c => c.AccountId == existingAccount.Id &&
c.Provider == provider &&
c.ProvidedIdentifier == userInfo.UserId)
.ExecuteUpdateAsync(s => s
.SetProperty(c => c.LastUsedAt, SystemClock.Instance.GetCurrentInstant())
.SetProperty(c => c.Meta, userInfo.ToMetadata()));
return existingAccount;
}
var connection = new AccountConnection var connection = new AccountConnection
{ {
AccountId = existingAccount.Id, AccountId = existingAccount.Id,
Provider = provider, Provider = provider,
ProvidedIdentifier = userInfo.UserId!, ProvidedIdentifier = userInfo.UserId!,
AccessToken = userInfo.AccessToken,
RefreshToken = userInfo.RefreshToken,
LastUsedAt = SystemClock.Instance.GetCurrentInstant(),
Meta = userInfo.ToMetadata()
}; };
db.AccountConnections.Add(connection); db.AccountConnections.Add(connection);
@ -139,6 +155,10 @@ public class OidcController(
AccountId = newAccount.Id, AccountId = newAccount.Id,
Provider = provider, Provider = provider,
ProvidedIdentifier = userInfo.UserId!, ProvidedIdentifier = userInfo.UserId!,
AccessToken = userInfo.AccessToken,
RefreshToken = userInfo.RefreshToken,
LastUsedAt = SystemClock.Instance.GetCurrentInstant(),
Meta = userInfo.ToMetadata()
}; };
db.AccountConnections.Add(newConnection); db.AccountConnections.Add(newConnection);

View File

@ -16,4 +16,34 @@ public class OidcUserInfo
public string Provider { get; set; } = ""; public string Provider { get; set; } = "";
public string? RefreshToken { get; set; } public string? RefreshToken { get; set; }
public string? AccessToken { get; set; } public string? AccessToken { get; set; }
public Dictionary<string, object> ToMetadata()
{
var metadata = new Dictionary<string, object>();
if (!string.IsNullOrWhiteSpace(UserId))
metadata["user_id"] = UserId;
if (!string.IsNullOrWhiteSpace(Email))
metadata["email"] = Email;
metadata["email_verified"] = EmailVerified;
if (!string.IsNullOrWhiteSpace(FirstName))
metadata["first_name"] = FirstName;
if (!string.IsNullOrWhiteSpace(LastName))
metadata["last_name"] = LastName;
if (!string.IsNullOrWhiteSpace(DisplayName))
metadata["display_name"] = DisplayName;
if (!string.IsNullOrWhiteSpace(PreferredUsername))
metadata["preferred_username"] = PreferredUsername;
if (!string.IsNullOrWhiteSpace(ProfilePictureUrl))
metadata["profile_picture_url"] = ProfilePictureUrl;
return metadata;
}
} }

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,29 @@
using System.Collections.Generic;
using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable
namespace DysonNetwork.Sphere.Migrations
{
/// <inheritdoc />
public partial class EnrichAccountConnection : Migration
{
/// <inheritdoc />
protected override void Up(MigrationBuilder migrationBuilder)
{
migrationBuilder.AddColumn<Dictionary<string, object>>(
name: "meta",
table: "account_connections",
type: "jsonb",
nullable: true);
}
/// <inheritdoc />
protected override void Down(MigrationBuilder migrationBuilder)
{
migrationBuilder.DropColumn(
name: "meta",
table: "account_connections");
}
}
}

View File

@ -172,6 +172,10 @@ namespace DysonNetwork.Sphere.Migrations
.HasColumnType("timestamp with time zone") .HasColumnType("timestamp with time zone")
.HasColumnName("last_used_at"); .HasColumnName("last_used_at");
b.Property<Dictionary<string, object>>("Meta")
.HasColumnType("jsonb")
.HasColumnName("meta");
b.Property<string>("ProvidedIdentifier") b.Property<string>("ProvidedIdentifier")
.IsRequired() .IsRequired()
.HasMaxLength(8192) .HasMaxLength(8192)