Passport/pkg/internal/services/perms.go

89 lines
2.2 KiB
Go
Raw Normal View History

2024-05-17 09:13:11 +00:00
package services
import (
"fmt"
2024-05-17 09:13:11 +00:00
"reflect"
"regexp"
"strings"
)
2024-05-17 11:24:14 +00:00
// HasPermNode reports whether perms contains requiredKey and the value stored
// there satisfies requiredValue as decided by ComparePermNode. An absent key
// never satisfies the requirement.
func HasPermNode(perms map[string]any, requiredKey string, requiredValue any) bool {
	held, present := perms[requiredKey]
	if !present {
		return false
	}
	return ComparePermNode(held, requiredValue)
}
// HasPermNodeWithDefault behaves like HasPermNode, except that when
// requiredKey is absent from perms, defaultValue is compared against
// requiredValue (via ComparePermNode) instead of failing outright.
func HasPermNodeWithDefault(perms map[string]any, requiredKey string, requiredValue any, defaultValue any) bool {
	held, present := perms[requiredKey]
	if !present {
		held = defaultValue
	}
	return ComparePermNode(held, requiredValue)
}
2024-05-17 11:24:14 +00:00
// ComparePermNode reports whether a held permission value satisfies a
// required value. The rule applied depends on the kind of held:
//
//   - string: satisfied only when required is also a string kind with
//     identical contents.
//   - slice or array: satisfied when any element is reflect.DeepEqual to
//     required.
//   - signed/unsigned integer or float: when required is also one of those
//     kinds, satisfied when held >= required after widening both to float64.
//   - anything else (including a numeric held with a non-numeric required):
//     satisfied when held and required are reflect.DeepEqual.
func ComparePermNode(held any, required any) bool {
	// isNumeric reports whether val is a signed/unsigned integer or a float.
	// Uintptr and complex kinds fall outside these ranges and are excluded.
	isNumeric := func(val reflect.Value) bool {
		kind := val.Kind()
		return kind >= reflect.Int && kind <= reflect.Uint64 || kind >= reflect.Float32 && kind <= reflect.Float64
	}
	// toFloat64 widens a numeric value so integers and floats can be compared
	// with each other. NOTE: integers with magnitude above 2^53 may lose
	// precision in the conversion.
	toFloat64 := func(val reflect.Value) float64 {
		switch val.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			return float64(val.Int())
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			return float64(val.Uint())
		case reflect.Float32, reflect.Float64:
			return val.Float()
		default:
			// Unreachable: only called after isNumeric has returned true.
			panic(fmt.Sprintf("non-numeric value of kind %s", val.Kind()))
		}
	}

	heldValue := reflect.ValueOf(held)
	requiredValue := reflect.ValueOf(required)

	switch heldValue.Kind() {
	case reflect.String:
		// reflect.Value.String returns placeholders such as "<int Value>" or
		// "<invalid Value>" for non-string kinds. Requiring a string on both
		// sides prevents those placeholders from producing false matches.
		return requiredValue.Kind() == reflect.String && heldValue.String() == requiredValue.String()
	case reflect.Slice, reflect.Array:
		for i := 0; i < heldValue.Len(); i++ {
			if reflect.DeepEqual(heldValue.Index(i).Interface(), required) {
				return true
			}
		}
		return false
	default:
		if isNumeric(heldValue) && isNumeric(requiredValue) {
			return toFloat64(heldValue) >= toFloat64(requiredValue)
		}
		return reflect.DeepEqual(held, required)
	}
}
// FilterPermNodes returns a new map holding the entries of tree whose keys
// are matched by at least one claim.
//
// A claim is a literal key in which "*" is a wildcard matching any run of
// characters (including none). Every other character, including ".", matches
// only itself, and the match must cover the whole key. Values in the result
// are the same values stored in tree. Empty or nil claims yield an empty map.
func FilterPermNodes(tree map[string]any, claims []string) map[string]any {
	filteredTree := make(map[string]any)
	for _, claim := range claims {
		// Escape each literal segment so regex metacharacters in a claim
		// (".", "+", "(", ...) are taken literally; only "*" is a wildcard.
		segments := strings.Split(claim, "*")
		for i, segment := range segments {
			segments[i] = regexp.QuoteMeta(segment)
		}
		// Compile once per claim rather than once per (claim, key) pair.
		pattern, err := regexp.Compile("^" + strings.Join(segments, ".*") + "$")
		if err != nil {
			// Unreachable in practice: quoted segments joined by ".*" always
			// form a valid pattern. Skip the claim rather than panic.
			continue
		}
		for key, value := range tree {
			if pattern.MatchString(key) {
				filteredTree[key] = value
			}
		}
	}
	return filteredTree
}