✨ Basis perm nodes feature
This commit is contained in:
		@@ -2,6 +2,8 @@ package services
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/spf13/viper"
 | 
			
		||||
	"gorm.io/datatypes"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.solsynth.dev/hydrogen/passport/pkg/database"
 | 
			
		||||
@@ -66,7 +68,7 @@ func CreateAccount(name, nick, email, password string) (models.Account, error) {
 | 
			
		||||
				VerifiedAt: nil,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		PowerLevel:  0,
 | 
			
		||||
		PermNodes:   datatypes.JSONMap(viper.GetStringMap("permissions.default")),
 | 
			
		||||
		ConfirmedAt: nil,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -98,7 +100,14 @@ func ConfirmAccount(code string) error {
 | 
			
		||||
 | 
			
		||||
	return database.C.Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		user.ConfirmedAt = lo.ToPtr(time.Now())
 | 
			
		||||
		user.PowerLevel += 5
 | 
			
		||||
 | 
			
		||||
		for k, v := range viper.GetStringMap("permissions.verified") {
 | 
			
		||||
			if val, ok := user.PermNodes[k]; !ok {
 | 
			
		||||
				user.PermNodes[k] = v
 | 
			
		||||
			} else if !HasPermNode(val, v) {
 | 
			
		||||
				user.PermNodes[k] = v
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err := database.C.Delete(&token).Error; err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
 
 | 
			
		||||
@@ -14,7 +14,7 @@ import (
 | 
			
		||||
 | 
			
		||||
const authContextBucket = "AuthContext"
 | 
			
		||||
 | 
			
		||||
func Authenticate(access, refresh string, depth int) (user models.Account, newAccess, newRefresh string, err error) {
 | 
			
		||||
func Authenticate(access, refresh string, depth int) (user models.Account, perms map[string]any, newAccess, newRefresh string, err error) {
 | 
			
		||||
	var claims PayloadClaims
 | 
			
		||||
	claims, err = DecodeJwt(access)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -37,6 +37,7 @@ func Authenticate(access, refresh string, depth int) (user models.Account, newAc
 | 
			
		||||
	ctx, lookupErr := GetAuthContext(claims.ID)
 | 
			
		||||
	if lookupErr == nil {
 | 
			
		||||
		log.Debug().Str("jti", claims.ID).Msg("Hit auth context cache once!")
 | 
			
		||||
		perms = FilterPermNodes(ctx.Account.PermNodes, ctx.Ticket.Claims)
 | 
			
		||||
		user = ctx.Account
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
@@ -44,6 +45,7 @@ func Authenticate(access, refresh string, depth int) (user models.Account, newAc
 | 
			
		||||
	ctx, err = GrantAuthContext(claims.ID)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		log.Debug().Str("jti", claims.ID).Err(lookupErr).Msg("Missed auth context cache once!")
 | 
			
		||||
		perms = FilterPermNodes(ctx.Account.PermNodes, ctx.Ticket.Claims)
 | 
			
		||||
		user = ctx.Account
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
@@ -97,7 +99,7 @@ func GrantAuthContext(jti string) (models.AuthContext, error) {
 | 
			
		||||
		return ctx, fmt.Errorf("invalid account: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Every context should expires in some while
 | 
			
		||||
	// Every context should expire in some while
 | 
			
		||||
	// Once user update their account info, this will have delay to update
 | 
			
		||||
	ctx = models.AuthContext{
 | 
			
		||||
		Ticket:    ticket,
 | 
			
		||||
 
 | 
			
		||||
@@ -15,7 +15,7 @@ type kexRequest struct {
 | 
			
		||||
 | 
			
		||||
var kexRequests = make(map[string]map[string]kexRequest)
 | 
			
		||||
 | 
			
		||||
func KexRequest(conn *websocket.Conn, requestId, keypairId string, ownerId uint, deadline int64) {
 | 
			
		||||
func KexRequest(conn *websocket.Conn, requestId, keypairId, algorithm string, ownerId uint, deadline int64) {
 | 
			
		||||
	if kexRequests[keypairId] == nil {
 | 
			
		||||
		kexRequests[keypairId] = make(map[string]kexRequest)
 | 
			
		||||
	}
 | 
			
		||||
@@ -38,6 +38,7 @@ func KexRequest(conn *websocket.Conn, requestId, keypairId string, ownerId uint,
 | 
			
		||||
			Payload: fiber.Map{
 | 
			
		||||
				"request_id": requestId,
 | 
			
		||||
				"keypair_id": keypairId,
 | 
			
		||||
				"algorithm":  algorithm,
 | 
			
		||||
				"owner_id":   ownerId,
 | 
			
		||||
				"deadline":   deadline,
 | 
			
		||||
			},
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										55
									
								
								pkg/services/perms.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								pkg/services/perms.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,55 @@
 | 
			
		||||
package services
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func HasPermNode(held any, required any) bool {
 | 
			
		||||
	heldValue := reflect.ValueOf(held)
 | 
			
		||||
	requiredValue := reflect.ValueOf(required)
 | 
			
		||||
 | 
			
		||||
	switch heldValue.Kind() {
 | 
			
		||||
	case reflect.Int, reflect.Float64:
 | 
			
		||||
		if heldValue.Float() >= requiredValue.Float() {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	case reflect.String:
 | 
			
		||||
		if heldValue.String() == requiredValue.String() {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	case reflect.Slice, reflect.Array:
 | 
			
		||||
		for i := 0; i < heldValue.Len(); i++ {
 | 
			
		||||
			if reflect.DeepEqual(heldValue.Index(i).Interface(), required) {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		if reflect.DeepEqual(held, required) {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func FilterPermNodes(tree map[string]any, claims []string) map[string]any {
 | 
			
		||||
	filteredTree := make(map[string]any)
 | 
			
		||||
 | 
			
		||||
	match := func(claim, permission string) bool {
 | 
			
		||||
		regex := strings.Replace(permission, "*", ".*", -1)
 | 
			
		||||
		match, _ := regexp.MatchString("^"+regex+"$", claim)
 | 
			
		||||
		return match
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, claim := range claims {
 | 
			
		||||
		for key, value := range tree {
 | 
			
		||||
			if match(claim, key) {
 | 
			
		||||
				filteredTree[key] = value
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return filteredTree
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user