Files
Swarm/DysonNetwork.Common/Services/Permission/PermissionMiddleware.cs

101 lines
3.3 KiB
C#

using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using System;
using System.Linq;
using System.Threading.Tasks;
using DysonNetwork.Common.Extensions;
using Microsoft.AspNetCore.Http.Extensions;
namespace DysonNetwork.Common.Services.Permission;
/// <summary>
/// Marks an action method as requiring a specific permission, identified by an
/// <see cref="Area"/> and a <see cref="Key"/>, before it may be executed.
/// </summary>
[AttributeUsage(AttributeTargets.Method, Inherited = true)]
public class RequiredPermissionAttribute : Attribute
{
    /// <summary>Creates the attribute for the given permission area and key.</summary>
    /// <param name="area">Permission area (group) the key belongs to.</param>
    /// <param name="key">Permission key within <paramref name="area"/>.</param>
    public RequiredPermissionAttribute(string area, string key)
    {
        Area = area;
        Key = key;
    }

    /// <summary>Permission area; settable so it can be overridden via a named attribute argument.</summary>
    public string Area { get; set; }

    /// <summary>Permission key within <see cref="Area"/>; fixed at construction.</summary>
    public string Key { get; }
}
/// <summary>
/// Middleware that enforces <see cref="RequiredPermissionAttribute"/> on the matched endpoint.
/// </summary>
/// <remarks>
/// If the endpoint is null or carries no <see cref="RequiredPermissionAttribute"/>, the request is
/// passed straight to the next middleware. Otherwise the caller must be authenticated, must have a
/// non-empty user id, and the actor <c>user:{id}</c> must hold the attribute's area/key permission.
/// </remarks>
/// <typeparam name="TDbContext">EF Core context resolved to back the permission lookup.</typeparam>
public class PermissionMiddleware<TDbContext> where TDbContext : DbContext
{
    private readonly RequestDelegate _next;
    private readonly IServiceProvider _serviceProvider;

    /// <param name="next">The next middleware in the pipeline.</param>
    /// <param name="serviceProvider">Provider used to create a scope for each permission lookup.</param>
    public PermissionMiddleware(RequestDelegate next, IServiceProvider serviceProvider)
    {
        _next = next;
        _serviceProvider = serviceProvider;
    }

    /// <summary>
    /// Checks the endpoint's required permission, if any. The request either continues down the
    /// pipeline or is rejected with a 403 response.
    /// </summary>
    /// <param name="httpContext">The current request context.</param>
    public async Task InvokeAsync(HttpContext httpContext)
    {
        var attr = httpContext.GetEndpoint()?.Metadata
            .OfType<RequiredPermissionAttribute>()
            .FirstOrDefault();

        if (attr is not null)
        {
            // NOTE(review): unauthenticated callers receive 403 with body "Unauthorized"; 401 may be
            // more accurate, but the existing status code is kept for client compatibility.
            if (httpContext.User.Identity?.IsAuthenticated != true)
            {
                await WriteRejectionAsync(httpContext, StatusCodes.Status403Forbidden, "Unauthorized");
                return;
            }

            var currentUserId = httpContext.User.GetUserId();
            if (currentUserId == Guid.Empty)
            {
                await WriteRejectionAsync(httpContext, StatusCodes.Status403Forbidden, "Unauthorized");
                return;
            }

            // TODO: Check for superuser from PassClient and skip the permission lookup for superusers.

            var actor = $"user:{currentUserId}";
            if (!await HasPermissionAsync(actor, attr))
            {
                await WriteRejectionAsync(httpContext, StatusCodes.Status403Forbidden, "Forbidden");
                return;
            }
        }

        await _next.Invoke(httpContext);
    }

    /// <summary>
    /// Runs the permission lookup inside a short-lived DI scope.
    /// </summary>
    /// <remarks>
    /// The scope is created only when an attributed endpoint actually needs a check. It is disposed
    /// before the rest of the pipeline runs, so the DbContext it supplies is not held open for the
    /// whole request.
    /// </remarks>
    private async Task<bool> HasPermissionAsync(string actor, RequiredPermissionAttribute attr)
    {
        using var scope = _serviceProvider.CreateScope();
        var permissionService = new PermissionService<TDbContext>(
            scope.ServiceProvider.GetRequiredService<TDbContext>(),
            scope.ServiceProvider.GetRequiredService<ICacheService>()
        );
        return await permissionService.HasPermissionAsync(actor, attr.Area, attr.Key);
    }

    /// <summary>Writes a plain-text rejection with the given status code and body.</summary>
    private static Task WriteRejectionAsync(HttpContext httpContext, int statusCode, string message)
    {
        httpContext.Response.StatusCode = statusCode;
        return httpContext.Response.WriteAsync(message);
    }
}
/// <summary>
/// Dependency-injection and pipeline registration helpers for the permission subsystem.
/// </summary>
public static class PermissionServiceExtensions
{
    /// <summary>
    /// Registers <see cref="PermissionService{TDbContext}"/> as a scoped service. Each instance is
    /// built from the scoped <typeparamref name="TDbContext"/> and <see cref="ICacheService"/>.
    /// </summary>
    /// <param name="services">The service collection to add to.</param>
    /// <returns>The same <paramref name="services"/> instance, for chaining.</returns>
    public static IServiceCollection AddPermissionService<TDbContext>(this IServiceCollection services)
        where TDbContext : DbContext
    {
        static PermissionService<TDbContext> CreatePermissionService(IServiceProvider provider)
        {
            var dbContext = provider.GetRequiredService<TDbContext>();
            var cacheService = provider.GetRequiredService<ICacheService>();
            return new PermissionService<TDbContext>(dbContext, cacheService);
        }

        services.AddScoped<PermissionService<TDbContext>>(CreatePermissionService);
        return services;
    }

    /// <summary>
    /// Adds <see cref="PermissionMiddleware{TDbContext}"/> to the pipeline, handing it the
    /// application's root service provider.
    /// </summary>
    /// <param name="builder">The application builder.</param>
    /// <returns>The builder returned by <c>UseMiddleware</c>, for chaining.</returns>
    public static IApplicationBuilder UsePermissionMiddleware<TDbContext>(this IApplicationBuilder builder)
        where TDbContext : DbContext
        => builder.UseMiddleware<PermissionMiddleware<TDbContext>>(builder.ApplicationServices);
}