🎨 Update project structure
This commit is contained in:
123
pkg/internal/services/accounts.go
Normal file
123
pkg/internal/services/accounts.go
Normal file
@ -0,0 +1,123 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/spf13/viper"
|
||||
"gorm.io/datatypes"
|
||||
"time"
|
||||
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/database"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
"github.com/google/uuid"
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func GetAccount(id uint) (models.Account, error) {
|
||||
var account models.Account
|
||||
if err := database.C.Where(models.Account{
|
||||
BaseModel: models.BaseModel{ID: id},
|
||||
}).First(&account).Error; err != nil {
|
||||
return account, err
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func LookupAccount(probe string) (models.Account, error) {
|
||||
var account models.Account
|
||||
if err := database.C.Where(models.Account{Name: probe}).First(&account).Error; err == nil {
|
||||
return account, nil
|
||||
}
|
||||
|
||||
var contact models.AccountContact
|
||||
if err := database.C.Where(models.AccountContact{Content: probe}).First(&contact).Error; err == nil {
|
||||
if err := database.C.
|
||||
Where(models.Account{
|
||||
BaseModel: models.BaseModel{ID: contact.AccountID},
|
||||
}).First(&account).Error; err == nil {
|
||||
return account, err
|
||||
}
|
||||
}
|
||||
|
||||
return account, fmt.Errorf("account was not found")
|
||||
}
|
||||
|
||||
func CreateAccount(name, nick, email, password string) (models.Account, error) {
|
||||
user := models.Account{
|
||||
Name: name,
|
||||
Nick: nick,
|
||||
Profile: models.AccountProfile{
|
||||
Experience: 100,
|
||||
},
|
||||
Factors: []models.AuthFactor{
|
||||
{
|
||||
Type: models.PasswordAuthFactor,
|
||||
Secret: HashPassword(password),
|
||||
},
|
||||
{
|
||||
Type: models.EmailPasswordFactor,
|
||||
Secret: uuid.NewString()[:8],
|
||||
},
|
||||
},
|
||||
Contacts: []models.AccountContact{
|
||||
{
|
||||
Type: models.EmailAccountContact,
|
||||
Content: email,
|
||||
IsPrimary: true,
|
||||
VerifiedAt: nil,
|
||||
},
|
||||
},
|
||||
PermNodes: datatypes.JSONMap(viper.GetStringMap("permissions.default")),
|
||||
ConfirmedAt: nil,
|
||||
}
|
||||
|
||||
if err := database.C.Create(&user).Error; err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
if tk, err := NewMagicToken(models.ConfirmMagicToken, &user, nil); err != nil {
|
||||
return user, err
|
||||
} else if err := NotifyMagicToken(tk); err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func ConfirmAccount(code string) error {
|
||||
token, err := ValidateMagicToken(code, models.ConfirmMagicToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var user models.Account
|
||||
if err := database.C.Where(&models.Account{
|
||||
BaseModel: models.BaseModel{ID: *token.AssignTo},
|
||||
}).First(&user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return database.C.Transaction(func(tx *gorm.DB) error {
|
||||
user.ConfirmedAt = lo.ToPtr(time.Now())
|
||||
|
||||
for k, v := range viper.GetStringMap("permissions.verified") {
|
||||
if val, ok := user.PermNodes[k]; !ok {
|
||||
user.PermNodes[k] = v
|
||||
} else if !ComparePermNode(val, v) {
|
||||
user.PermNodes[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
if err := database.C.Delete(&token).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := database.C.Save(&user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
InvalidAuthCacheWithUser(user.ID)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
125
pkg/internal/services/auth.go
Normal file
125
pkg/internal/services/auth.go
Normal file
@ -0,0 +1,125 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
jsoniter "github.com/json-iterator/go"
|
||||
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
var (
|
||||
authContextMutex sync.Mutex
|
||||
authContextCache = make(map[string]models.AuthContext)
|
||||
)
|
||||
|
||||
func Authenticate(access, refresh string, depth int) (ctx models.AuthContext, perms map[string]any, newAccess, newRefresh string, err error) {
|
||||
var claims PayloadClaims
|
||||
claims, err = DecodeJwt(access)
|
||||
if err != nil {
|
||||
if len(refresh) > 0 && depth < 1 {
|
||||
// Auto refresh and retry
|
||||
newAccess, newRefresh, err = RefreshToken(refresh)
|
||||
if err == nil {
|
||||
return Authenticate(newAccess, newRefresh, depth+1)
|
||||
}
|
||||
}
|
||||
err = fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("invalid auth key: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
newAccess = access
|
||||
newRefresh = refresh
|
||||
|
||||
if ctx, err = GetAuthContext(claims.ID); err == nil {
|
||||
var heldPerms map[string]any
|
||||
rawHeldPerms, _ := jsoniter.Marshal(ctx.Account.PermNodes)
|
||||
_ = jsoniter.Unmarshal(rawHeldPerms, &heldPerms)
|
||||
|
||||
perms = FilterPermNodes(heldPerms, ctx.Ticket.Claims)
|
||||
return
|
||||
}
|
||||
|
||||
err = fiber.NewError(fiber.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
func GetAuthContext(jti string) (models.AuthContext, error) {
|
||||
var err error
|
||||
var ctx models.AuthContext
|
||||
|
||||
if val, ok := authContextCache[jti]; ok {
|
||||
ctx = val
|
||||
ctx.LastUsedAt = time.Now()
|
||||
authContextMutex.Lock()
|
||||
authContextCache[jti] = ctx
|
||||
authContextMutex.Unlock()
|
||||
log.Debug().Str("jti", jti).Msg("Used an auth context cache")
|
||||
} else {
|
||||
ctx, err = CacheAuthContext(jti)
|
||||
log.Debug().Str("jti", jti).Msg("Created a new auth context cache")
|
||||
}
|
||||
|
||||
return ctx, err
|
||||
}
|
||||
|
||||
func CacheAuthContext(jti string) (models.AuthContext, error) {
|
||||
var ctx models.AuthContext
|
||||
|
||||
// Query data from primary database
|
||||
ticket, err := GetTicketWithToken(jti)
|
||||
if err != nil {
|
||||
return ctx, fmt.Errorf("invalid auth ticket: %v", err)
|
||||
} else if err := ticket.IsAvailable(); err != nil {
|
||||
return ctx, fmt.Errorf("unavailable auth ticket: %v", err)
|
||||
}
|
||||
|
||||
user, err := GetAccount(ticket.AccountID)
|
||||
if err != nil {
|
||||
return ctx, fmt.Errorf("invalid account: %v", err)
|
||||
}
|
||||
|
||||
ctx = models.AuthContext{
|
||||
Ticket: ticket,
|
||||
Account: user,
|
||||
LastUsedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Put the data into memory for cache
|
||||
authContextMutex.Lock()
|
||||
authContextCache[jti] = ctx
|
||||
authContextMutex.Unlock()
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func RecycleAuthContext() {
|
||||
if len(authContextCache) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
affected := 0
|
||||
for key, val := range authContextCache {
|
||||
if val.LastUsedAt.Add(60*time.Second).Unix() < time.Now().Unix() {
|
||||
affected++
|
||||
authContextMutex.Lock()
|
||||
delete(authContextCache, key)
|
||||
authContextMutex.Unlock()
|
||||
}
|
||||
}
|
||||
log.Debug().Int("affected", affected).Msg("Recycled auth context...")
|
||||
}
|
||||
|
||||
func InvalidAuthCacheWithUser(userId uint) {
|
||||
for key, val := range authContextCache {
|
||||
if val.Account.ID == userId {
|
||||
authContextMutex.Lock()
|
||||
delete(authContextCache, key)
|
||||
authContextMutex.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
15
pkg/internal/services/badges.go
Normal file
15
pkg/internal/services/badges.go
Normal file
@ -0,0 +1,15 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/database"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
)
|
||||
|
||||
func GrantBadge(user models.Account, badge models.Badge) error {
|
||||
badge.AccountID = user.ID
|
||||
return database.C.Save(badge).Error
|
||||
}
|
||||
|
||||
func RevokeBadge(badge models.Badge) error {
|
||||
return database.C.Delete(&badge).Error
|
||||
}
|
23
pkg/internal/services/cleaner.go
Normal file
23
pkg/internal/services/cleaner.go
Normal file
@ -0,0 +1,23 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/database"
|
||||
"github.com/rs/zerolog/log"
|
||||
"time"
|
||||
)
|
||||
|
||||
func DoAutoDatabaseCleanup() {
|
||||
deadline := time.Now().Add(60 * time.Minute)
|
||||
log.Debug().Time("deadline", deadline).Msg("Now cleaning up entire database...")
|
||||
|
||||
var count int64
|
||||
for _, model := range database.AutoMaintainRange {
|
||||
tx := database.C.Unscoped().Delete(model, "deleted_at >= ?", deadline)
|
||||
if tx.Error != nil {
|
||||
log.Error().Err(tx.Error).Msg("An error occurred when running auth context cleanup...")
|
||||
}
|
||||
count += tx.RowsAffected
|
||||
}
|
||||
|
||||
log.Debug().Int64("affected", count).Msg("Clean up entire database accomplished.")
|
||||
}
|
32
pkg/internal/services/clients.go
Normal file
32
pkg/internal/services/clients.go
Normal file
@ -0,0 +1,32 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/database"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
)
|
||||
|
||||
func GetThirdClient(id string) (models.ThirdClient, error) {
|
||||
var client models.ThirdClient
|
||||
if err := database.C.Where(&models.ThirdClient{
|
||||
Alias: id,
|
||||
}).First(&client).Error; err != nil {
|
||||
return client, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func GetThirdClientWithSecret(id, secret string) (models.ThirdClient, error) {
|
||||
client, err := GetThirdClient(id)
|
||||
if err != nil {
|
||||
return client, err
|
||||
}
|
||||
|
||||
if client.Secret != secret {
|
||||
return client, fmt.Errorf("invalid client secret")
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
31
pkg/internal/services/connections.go
Normal file
31
pkg/internal/services/connections.go
Normal file
@ -0,0 +1,31 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
"github.com/gofiber/contrib/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
wsMutex sync.Mutex
|
||||
wsConn = make(map[uint]map[*websocket.Conn]bool)
|
||||
)
|
||||
|
||||
func ClientRegister(user models.Account, conn *websocket.Conn) {
|
||||
wsMutex.Lock()
|
||||
if wsConn[user.ID] == nil {
|
||||
wsConn[user.ID] = make(map[*websocket.Conn]bool)
|
||||
}
|
||||
wsConn[user.ID][conn] = true
|
||||
wsMutex.Unlock()
|
||||
}
|
||||
|
||||
func ClientUnregister(user models.Account, conn *websocket.Conn) {
|
||||
wsMutex.Lock()
|
||||
if wsConn[user.ID] == nil {
|
||||
wsConn[user.ID] = make(map[*websocket.Conn]bool)
|
||||
}
|
||||
delete(wsConn[user.ID], conn)
|
||||
wsMutex.Unlock()
|
||||
}
|
82
pkg/internal/services/e2ee.go
Normal file
82
pkg/internal/services/e2ee.go
Normal file
@ -0,0 +1,82 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
"github.com/gofiber/contrib/websocket"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"time"
|
||||
)
|
||||
|
||||
type kexRequest struct {
|
||||
OwnerID uint
|
||||
Conn *websocket.Conn
|
||||
Deadline time.Time
|
||||
}
|
||||
|
||||
var kexRequests = make(map[string]map[string]kexRequest)
|
||||
|
||||
func KexRequest(conn *websocket.Conn, requestId, keypairId, algorithm string, ownerId uint, deadline int64) {
|
||||
if kexRequests[keypairId] == nil {
|
||||
kexRequests[keypairId] = make(map[string]kexRequest)
|
||||
}
|
||||
|
||||
ddl := time.Now().Add(time.Second * time.Duration(deadline))
|
||||
request := kexRequest{
|
||||
OwnerID: ownerId,
|
||||
Conn: conn,
|
||||
Deadline: ddl,
|
||||
}
|
||||
|
||||
flag := false
|
||||
for c := range wsConn[ownerId] {
|
||||
if c == conn {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.WriteMessage(1, models.UnifiedCommand{
|
||||
Action: "kex.request",
|
||||
Payload: fiber.Map{
|
||||
"request_id": requestId,
|
||||
"keypair_id": keypairId,
|
||||
"algorithm": algorithm,
|
||||
"owner_id": ownerId,
|
||||
"deadline": deadline,
|
||||
},
|
||||
}.Marshal()) == nil {
|
||||
flag = true
|
||||
}
|
||||
}
|
||||
|
||||
if flag {
|
||||
kexRequests[keypairId][requestId] = request
|
||||
}
|
||||
}
|
||||
|
||||
func KexProvide(userId uint, requestId, keypairId string, pkt []byte) {
|
||||
if kexRequests[keypairId] == nil {
|
||||
return
|
||||
}
|
||||
|
||||
val, ok := kexRequests[keypairId][requestId]
|
||||
if !ok {
|
||||
return
|
||||
} else if val.OwnerID != userId {
|
||||
return
|
||||
} else {
|
||||
_ = val.Conn.WriteMessage(1, pkt)
|
||||
}
|
||||
}
|
||||
|
||||
func KexCleanup() {
|
||||
if len(kexRequests) <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for kp, data := range kexRequests {
|
||||
for idx, req := range data {
|
||||
if req.Deadline.Unix() <= time.Now().Unix() {
|
||||
delete(kexRequests[kp], idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
12
pkg/internal/services/encryptor.go
Normal file
12
pkg/internal/services/encryptor.go
Normal file
@ -0,0 +1,12 @@
|
||||
package services
|
||||
|
||||
import "golang.org/x/crypto/bcrypt"
|
||||
|
||||
func HashPassword(raw string) string {
|
||||
data, _ := bcrypt.GenerateFromPassword([]byte(raw), 12)
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func VerifyPassword(text string, password string) bool {
|
||||
return bcrypt.CompareHashAndPassword([]byte(password), []byte(text)) == nil
|
||||
}
|
20
pkg/internal/services/events.go
Normal file
20
pkg/internal/services/events.go
Normal file
@ -0,0 +1,20 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/database"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
)
|
||||
|
||||
func AddEvent(user models.Account, event, target, ip, ua string) models.ActionEvent {
|
||||
evt := models.ActionEvent{
|
||||
Type: event,
|
||||
Target: target,
|
||||
IpAddress: ip,
|
||||
UserAgent: ua,
|
||||
AccountID: user.ID,
|
||||
}
|
||||
|
||||
database.C.Save(&evt)
|
||||
|
||||
return evt
|
||||
}
|
25
pkg/internal/services/external_apns.go
Normal file
25
pkg/internal/services/external_apns.go
Normal file
@ -0,0 +1,25 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"github.com/sideshow/apns2"
|
||||
"github.com/sideshow/apns2/token"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// ExtAPNS is Apple Notification Services client
|
||||
var ExtAPNS *apns2.Client
|
||||
|
||||
func SetupAPNS() error {
|
||||
authKey, err := token.AuthKeyFromFile(viper.GetString("apns_credentials"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ExtAPNS = apns2.NewTokenClient(&token.Token{
|
||||
AuthKey: authKey,
|
||||
KeyID: viper.GetString("apns_credentials_key"),
|
||||
TeamID: viper.GetString("apns_credentials_team"),
|
||||
}).Production()
|
||||
|
||||
return nil
|
||||
}
|
23
pkg/internal/services/external_firebase.go
Normal file
23
pkg/internal/services/external_firebase.go
Normal file
@ -0,0 +1,23 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
firebase "firebase.google.com/go"
|
||||
"github.com/spf13/viper"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
// ExtFire is the firebase app client
|
||||
var ExtFire *firebase.App
|
||||
|
||||
func SetupFirebase() error {
|
||||
opt := option.WithCredentialsFile(viper.GetString("firebase_credentials"))
|
||||
app, err := firebase.NewApp(context.Background(), nil, opt)
|
||||
if err != nil {
|
||||
return err
|
||||
} else {
|
||||
ExtFire = app
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
109
pkg/internal/services/factors.go
Normal file
109
pkg/internal/services/factors.go
Normal file
@ -0,0 +1,109 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/samber/lo"
|
||||
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/database"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
const EmailPasswordTemplate = `Dear %s,
|
||||
|
||||
We hope this message finds you well.
|
||||
As part of our ongoing commitment to ensuring the security of your account, we require you to complete the login process by entering the verification code below:
|
||||
|
||||
Your Login Verification Code: %s
|
||||
|
||||
Please use the provided code within the next 2 hours to complete your login.
|
||||
If you did not request this code, please update your information, maybe your username or email has been leak.
|
||||
|
||||
Thank you for your cooperation in helping us maintain the security of your account.
|
||||
|
||||
Best regards,
|
||||
%s`
|
||||
|
||||
func GetPasswordTypeFactor(userId uint) (models.AuthFactor, error) {
|
||||
var factor models.AuthFactor
|
||||
err := database.C.Where(models.AuthFactor{
|
||||
Type: models.PasswordAuthFactor,
|
||||
AccountID: userId,
|
||||
}).First(&factor).Error
|
||||
|
||||
return factor, err
|
||||
}
|
||||
|
||||
func GetFactor(id uint) (models.AuthFactor, error) {
|
||||
var factor models.AuthFactor
|
||||
err := database.C.Where(models.AuthFactor{
|
||||
BaseModel: models.BaseModel{ID: id},
|
||||
}).First(&factor).Error
|
||||
|
||||
return factor, err
|
||||
}
|
||||
|
||||
func ListUserFactor(userId uint) ([]models.AuthFactor, error) {
|
||||
var factors []models.AuthFactor
|
||||
err := database.C.Where(models.AuthFactor{
|
||||
AccountID: userId,
|
||||
}).Find(&factors).Error
|
||||
|
||||
return factors, err
|
||||
}
|
||||
|
||||
func CountUserFactor(userId uint) int64 {
|
||||
var count int64
|
||||
database.C.Where(models.AuthFactor{
|
||||
AccountID: userId,
|
||||
}).Model(&models.AuthFactor{}).Count(&count)
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
func GetFactorCode(factor models.AuthFactor) (bool, error) {
|
||||
switch factor.Type {
|
||||
case models.EmailPasswordFactor:
|
||||
var user models.Account
|
||||
if err := database.C.Where(&models.Account{
|
||||
BaseModel: models.BaseModel{ID: factor.AccountID},
|
||||
}).Preload("Contacts").First(&user).Error; err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
factor.Secret = uuid.NewString()[:6]
|
||||
if err := database.C.Save(&factor).Error; err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
subject := fmt.Sprintf("[%s] Login verification code", viper.GetString("name"))
|
||||
content := fmt.Sprintf(EmailPasswordTemplate, user.Name, factor.Secret, viper.GetString("maintainer"))
|
||||
if err := SendMail(user.GetPrimaryEmail().Content, subject, content); err != nil {
|
||||
return true, err
|
||||
}
|
||||
return true, nil
|
||||
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func CheckFactor(factor models.AuthFactor, code string) error {
|
||||
switch factor.Type {
|
||||
case models.PasswordAuthFactor:
|
||||
return lo.Ternary(
|
||||
VerifyPassword(code, factor.Secret),
|
||||
nil,
|
||||
fmt.Errorf("invalid password"),
|
||||
)
|
||||
case models.EmailPasswordFactor:
|
||||
return lo.Ternary(
|
||||
code == factor.Secret,
|
||||
nil,
|
||||
fmt.Errorf("invalid verification code"),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
125
pkg/internal/services/friendships.go
Normal file
125
pkg/internal/services/friendships.go
Normal file
@ -0,0 +1,125 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/database"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func ListAllFriend(anyside models.Account) ([]models.AccountFriendship, error) {
|
||||
var relationships []models.AccountFriendship
|
||||
if err := database.C.
|
||||
Where("account_id = ? OR related_id = ?", anyside.ID, anyside.ID).
|
||||
Preload("Account").
|
||||
Preload("Related").
|
||||
Find(&relationships).Error; err != nil {
|
||||
return relationships, err
|
||||
}
|
||||
|
||||
return relationships, nil
|
||||
}
|
||||
|
||||
func ListFriend(anyside models.Account, status models.FriendshipStatus) ([]models.AccountFriendship, error) {
|
||||
var relationships []models.AccountFriendship
|
||||
if err := database.C.
|
||||
Where("(account_id = ? OR related_id = ?) AND status = ?", anyside.ID, anyside.ID, status).
|
||||
Preload("Account").
|
||||
Preload("Related").
|
||||
Find(&relationships).Error; err != nil {
|
||||
return relationships, err
|
||||
}
|
||||
|
||||
return relationships, nil
|
||||
}
|
||||
|
||||
func GetFriend(anysideId uint) (models.AccountFriendship, error) {
|
||||
var relationship models.AccountFriendship
|
||||
if err := database.C.
|
||||
Where(&models.AccountFriendship{AccountID: anysideId}).
|
||||
Or(&models.AccountFriendship{RelatedID: anysideId}).
|
||||
Preload("Account").
|
||||
Preload("Related").
|
||||
First(&relationship).Error; err != nil {
|
||||
return relationship, err
|
||||
}
|
||||
|
||||
return relationship, nil
|
||||
}
|
||||
|
||||
func GetFriendWithTwoSides(userId, relatedId uint, noPreload ...bool) (models.AccountFriendship, error) {
|
||||
var tx *gorm.DB
|
||||
if len(noPreload) > 0 && noPreload[0] {
|
||||
tx = database.C
|
||||
} else {
|
||||
tx = database.C.Preload("Account").Preload("Related")
|
||||
}
|
||||
|
||||
var relationship models.AccountFriendship
|
||||
if err := tx.
|
||||
Where(&models.AccountFriendship{AccountID: userId, RelatedID: relatedId}).
|
||||
Or(&models.AccountFriendship{RelatedID: userId, AccountID: relatedId}).
|
||||
First(&relationship).Error; err != nil {
|
||||
return relationship, err
|
||||
}
|
||||
|
||||
return relationship, nil
|
||||
}
|
||||
|
||||
func NewFriend(user models.Account, related models.Account, status models.FriendshipStatus) (models.AccountFriendship, error) {
|
||||
relationship := models.AccountFriendship{
|
||||
AccountID: user.ID,
|
||||
RelatedID: related.ID,
|
||||
Status: status,
|
||||
}
|
||||
|
||||
if user.ID == related.ID {
|
||||
return relationship, fmt.Errorf("you cannot make friendship with yourself")
|
||||
} else if _, err := GetFriendWithTwoSides(user.ID, related.ID, true); err == nil || !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return relationship, fmt.Errorf("you already have a friendship with him or her")
|
||||
}
|
||||
|
||||
if err := database.C.Save(&relationship).Error; err != nil {
|
||||
return relationship, err
|
||||
} else {
|
||||
_ = NewNotification(models.Notification{
|
||||
Subject: fmt.Sprintf("New friend request from %s", user.Name),
|
||||
Content: fmt.Sprintf("You got a new friend request from %s. Go to your settings and decide how to deal it.", user.Nick),
|
||||
RecipientID: related.ID,
|
||||
})
|
||||
}
|
||||
|
||||
return relationship, nil
|
||||
}
|
||||
|
||||
func EditFriendWithCheck(relationship models.AccountFriendship, user models.Account, originalStatus models.FriendshipStatus) (models.AccountFriendship, error) {
|
||||
if relationship.Status != originalStatus {
|
||||
if originalStatus == models.FriendshipBlocked && relationship.BlockedBy != nil && user.ID != *relationship.BlockedBy {
|
||||
return relationship, fmt.Errorf("the friendship has been blocked by the otherside, you cannot modify it status")
|
||||
}
|
||||
if relationship.Status == models.FriendshipPending && relationship.RelatedID != user.ID {
|
||||
return relationship, fmt.Errorf("only related person can accept friendship")
|
||||
}
|
||||
}
|
||||
if originalStatus != models.FriendshipBlocked && relationship.Status == models.FriendshipBlocked {
|
||||
relationship.BlockedBy = &user.ID
|
||||
}
|
||||
|
||||
return EditFriend(relationship)
|
||||
}
|
||||
|
||||
func EditFriend(relationship models.AccountFriendship) (models.AccountFriendship, error) {
|
||||
if err := database.C.Save(&relationship).Error; err != nil {
|
||||
return relationship, err
|
||||
}
|
||||
return relationship, nil
|
||||
}
|
||||
|
||||
func DeleteFriend(relationship models.AccountFriendship) error {
|
||||
if err := database.C.Delete(&relationship).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
81
pkg/internal/services/jwt.go
Normal file
81
pkg/internal/services/jwt.go
Normal file
@ -0,0 +1,81 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var CookieAccessKey = "passport_auth_key"
|
||||
var CookieRefreshKey = "passport_refresh_key"
|
||||
|
||||
type PayloadClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
|
||||
SessionID string `json:"sed"`
|
||||
Type string `json:"typ"`
|
||||
}
|
||||
|
||||
const (
|
||||
JwtAccessType = "access"
|
||||
JwtRefreshType = "refresh"
|
||||
)
|
||||
|
||||
func EncodeJwt(id string, typ, sub, sed string, aud []string, exp time.Time) (string, error) {
|
||||
tk := jwt.NewWithClaims(jwt.SigningMethodHS512, PayloadClaims{
|
||||
jwt.RegisteredClaims{
|
||||
Subject: sub,
|
||||
Audience: aud,
|
||||
Issuer: fmt.Sprintf("https://%s", viper.GetString("domain")),
|
||||
ExpiresAt: jwt.NewNumericDate(exp),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
ID: id,
|
||||
},
|
||||
sed,
|
||||
typ,
|
||||
})
|
||||
|
||||
return tk.SignedString([]byte(viper.GetString("secret")))
|
||||
}
|
||||
|
||||
func DecodeJwt(str string) (PayloadClaims, error) {
|
||||
var claims PayloadClaims
|
||||
tk, err := jwt.ParseWithClaims(str, &claims, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(viper.GetString("secret")), nil
|
||||
})
|
||||
if err != nil {
|
||||
return claims, err
|
||||
}
|
||||
|
||||
if data, ok := tk.Claims.(*PayloadClaims); ok {
|
||||
return *data, nil
|
||||
} else {
|
||||
return claims, fmt.Errorf("unexpected token payload: not payload claims type")
|
||||
}
|
||||
}
|
||||
|
||||
func SetJwtCookieSet(c *fiber.Ctx, access, refresh string) {
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: CookieAccessKey,
|
||||
Value: access,
|
||||
Domain: viper.GetString("security.cookie_domain"),
|
||||
SameSite: viper.GetString("security.cookie_samesite"),
|
||||
Expires: time.Now().Add(60 * time.Minute),
|
||||
Path: "/",
|
||||
})
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: CookieRefreshKey,
|
||||
Value: refresh,
|
||||
Domain: viper.GetString("security.cookie_domain"),
|
||||
SameSite: viper.GetString("security.cookie_samesite"),
|
||||
Expires: time.Now().Add(24 * 30 * time.Hour),
|
||||
Path: "/",
|
||||
})
|
||||
}
|
51
pkg/internal/services/mailer.go
Normal file
51
pkg/internal/services/mailer.go
Normal file
@ -0,0 +1,51 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
"net/textproto"
|
||||
|
||||
"github.com/jordan-wright/email"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func SendMail(target string, subject string, content string) error {
|
||||
mail := &email.Email{
|
||||
To: []string{target},
|
||||
From: viper.GetString("mailer.name"),
|
||||
Subject: subject,
|
||||
Text: []byte(content),
|
||||
Headers: textproto.MIMEHeader{},
|
||||
}
|
||||
return mail.SendWithTLS(
|
||||
fmt.Sprintf("%s:%d", viper.GetString("mailer.smtp_host"), viper.GetInt("mailer.smtp_port")),
|
||||
smtp.PlainAuth(
|
||||
"",
|
||||
viper.GetString("mailer.username"),
|
||||
viper.GetString("mailer.password"),
|
||||
viper.GetString("mailer.smtp_host"),
|
||||
),
|
||||
&tls.Config{ServerName: viper.GetString("mailer.smtp_host")},
|
||||
)
|
||||
}
|
||||
|
||||
func SendMailHTML(target string, subject string, content string) error {
|
||||
mail := &email.Email{
|
||||
To: []string{target},
|
||||
From: viper.GetString("mailer.name"),
|
||||
Subject: subject,
|
||||
HTML: []byte(content),
|
||||
Headers: textproto.MIMEHeader{},
|
||||
}
|
||||
return mail.SendWithTLS(
|
||||
fmt.Sprintf("%s:%d", viper.GetString("mailer.smtp_host"), viper.GetInt("mailer.smtp_port")),
|
||||
smtp.PlainAuth(
|
||||
"",
|
||||
viper.GetString("mailer.username"),
|
||||
viper.GetString("mailer.password"),
|
||||
viper.GetString("mailer.smtp_host"),
|
||||
),
|
||||
&tls.Config{ServerName: viper.GetString("mailer.smtp_host")},
|
||||
)
|
||||
}
|
18
pkg/internal/services/mfa.go
Normal file
18
pkg/internal/services/mfa.go
Normal file
@ -0,0 +1,18 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
"github.com/nicksnyder/go-i18n/v2/i18n"
|
||||
)
|
||||
|
||||
func GetFactorName(w models.AuthFactorType, localizer *i18n.Localizer) string {
|
||||
unknown, _ := localizer.LocalizeMessage(&i18n.Message{ID: "unknown"})
|
||||
mfaEmail, _ := localizer.LocalizeMessage(&i18n.Message{ID: "mfaFactorEmail"})
|
||||
|
||||
switch w {
|
||||
case models.EmailPasswordFactor:
|
||||
return mfaEmail
|
||||
default:
|
||||
return unknown
|
||||
}
|
||||
}
|
138
pkg/internal/services/notifications.go
Normal file
138
pkg/internal/services/notifications.go
Normal file
@ -0,0 +1,138 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"firebase.google.com/go/messaging"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/database"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sideshow/apns2"
|
||||
payload2 "github.com/sideshow/apns2/payload"
|
||||
"github.com/spf13/viper"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func AddNotifySubscriber(user models.Account, provider, id, tk, ua string) (models.NotificationSubscriber, error) {
|
||||
var prev models.NotificationSubscriber
|
||||
var subscriber models.NotificationSubscriber
|
||||
if err := database.C.Where(&models.NotificationSubscriber{
|
||||
DeviceID: id,
|
||||
AccountID: user.ID,
|
||||
}); err != nil {
|
||||
subscriber = models.NotificationSubscriber{
|
||||
UserAgent: ua,
|
||||
Provider: provider,
|
||||
DeviceID: id,
|
||||
DeviceToken: tk,
|
||||
AccountID: user.ID,
|
||||
}
|
||||
} else {
|
||||
prev = subscriber
|
||||
}
|
||||
|
||||
subscriber.UserAgent = ua
|
||||
subscriber.Provider = provider
|
||||
subscriber.DeviceToken = tk
|
||||
|
||||
var err error
|
||||
if !reflect.DeepEqual(subscriber, prev) {
|
||||
err = database.C.Save(&subscriber).Error
|
||||
}
|
||||
|
||||
return subscriber, err
|
||||
}
|
||||
|
||||
func NewNotification(notification models.Notification) error {
|
||||
if err := database.C.Save(¬ification).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := PushNotification(notification)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Unexpected error occurred during the notification.")
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func PushNotification(notification models.Notification) error {
|
||||
for conn := range wsConn[notification.RecipientID] {
|
||||
_ = conn.WriteMessage(1, models.UnifiedCommand{
|
||||
Action: "notifications.new",
|
||||
Payload: notification,
|
||||
}.Marshal())
|
||||
}
|
||||
|
||||
// TODO Detect the push notification is turned off (still push when IsForcePush is on)
|
||||
|
||||
var subscribers []models.NotificationSubscriber
|
||||
if err := database.C.Where(&models.NotificationSubscriber{
|
||||
AccountID: notification.RecipientID,
|
||||
}).Find(&subscribers).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, subscriber := range subscribers {
|
||||
switch subscriber.Provider {
|
||||
case models.NotifySubscriberFirebase:
|
||||
if ExtFire != nil {
|
||||
ctx := context.Background()
|
||||
client, err := ExtFire.Messaging(ctx)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("An error occurred when creating FCM client...")
|
||||
break
|
||||
}
|
||||
|
||||
message := &messaging.Message{
|
||||
Notification: &messaging.Notification{
|
||||
Title: notification.Subject,
|
||||
Body: notification.Content,
|
||||
},
|
||||
Token: subscriber.DeviceToken,
|
||||
}
|
||||
|
||||
if response, err := client.Send(ctx, message); err != nil {
|
||||
log.Warn().Err(err).Msg("An error occurred when notify subscriber via FCM...")
|
||||
} else {
|
||||
log.Debug().
|
||||
Str("response", response).
|
||||
Int("subscriber", int(subscriber.ID)).
|
||||
Msg("Notified subscriber via FCM.")
|
||||
}
|
||||
}
|
||||
case models.NotifySubscriberAPNs:
|
||||
if ExtAPNS != nil {
|
||||
data, err := payload2.
|
||||
NewPayload().
|
||||
AlertTitle(notification.Subject).
|
||||
AlertBody(notification.Content).
|
||||
Sound("default").
|
||||
Category(notification.Type).
|
||||
MarshalJSON()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("An error occurred when preparing to notify subscriber via APNs...")
|
||||
}
|
||||
payload := &apns2.Notification{
|
||||
ApnsID: subscriber.DeviceID,
|
||||
DeviceToken: subscriber.DeviceToken,
|
||||
Topic: viper.GetString("apns_topic"),
|
||||
Payload: data,
|
||||
}
|
||||
|
||||
if resp, err := ExtAPNS.Push(payload); err != nil {
|
||||
log.Warn().Err(err).Msg("An error occurred when notify subscriber via APNs...")
|
||||
} else {
|
||||
log.Debug().
|
||||
Str("reason", resp.Reason).
|
||||
Int("status", resp.StatusCode).
|
||||
Int("subscriber", int(subscriber.ID)).
|
||||
Msg("Notified subscriber via APNs.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
63
pkg/internal/services/perms.go
Normal file
63
pkg/internal/services/perms.go
Normal file
@ -0,0 +1,63 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func HasPermNode(perms map[string]any, requiredKey string, requiredValue any) bool {
|
||||
if heldValue, ok := perms[requiredKey]; ok {
|
||||
return ComparePermNode(heldValue, requiredValue)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ComparePermNode(held any, required any) bool {
|
||||
heldValue := reflect.ValueOf(held)
|
||||
requiredValue := reflect.ValueOf(required)
|
||||
|
||||
switch heldValue.Kind() {
|
||||
case reflect.Int, reflect.Float64:
|
||||
if heldValue.Float() >= requiredValue.Float() {
|
||||
return true
|
||||
}
|
||||
case reflect.String:
|
||||
if heldValue.String() == requiredValue.String() {
|
||||
return true
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < heldValue.Len(); i++ {
|
||||
if reflect.DeepEqual(heldValue.Index(i).Interface(), required) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
default:
|
||||
if reflect.DeepEqual(held, required) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func FilterPermNodes(tree map[string]any, claims []string) map[string]any {
|
||||
filteredTree := make(map[string]any)
|
||||
|
||||
match := func(claim, permission string) bool {
|
||||
regex := strings.ReplaceAll(claim, "*", ".*")
|
||||
match, _ := regexp.MatchString(fmt.Sprintf("^%s$", regex), permission)
|
||||
return match
|
||||
}
|
||||
|
||||
for _, claim := range claims {
|
||||
for key, value := range tree {
|
||||
if match(claim, key) {
|
||||
filteredTree[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return filteredTree
|
||||
}
|
141
pkg/internal/services/realms.go
Normal file
141
pkg/internal/services/realms.go
Normal file
@ -0,0 +1,141 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/database"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
func ListCommunityRealm() ([]models.Realm, error) {
|
||||
var realms []models.Realm
|
||||
if err := database.C.Where(&models.Realm{
|
||||
IsCommunity: true,
|
||||
}).Find(&realms).Error; err != nil {
|
||||
return realms, err
|
||||
}
|
||||
|
||||
return realms, nil
|
||||
}
|
||||
|
||||
func ListOwnedRealm(user models.Account) ([]models.Realm, error) {
|
||||
var realms []models.Realm
|
||||
if err := database.C.Where(&models.Realm{AccountID: user.ID}).Find(&realms).Error; err != nil {
|
||||
return realms, err
|
||||
}
|
||||
|
||||
return realms, nil
|
||||
}
|
||||
|
||||
func ListAvailableRealm(user models.Account) ([]models.Realm, error) {
|
||||
var realms []models.Realm
|
||||
var members []models.RealmMember
|
||||
if err := database.C.Where(&models.RealmMember{
|
||||
AccountID: user.ID,
|
||||
}).Find(&members).Error; err != nil {
|
||||
return realms, err
|
||||
}
|
||||
|
||||
idx := lo.Map(members, func(item models.RealmMember, index int) uint {
|
||||
return item.RealmID
|
||||
})
|
||||
|
||||
if err := database.C.Where("id IN ?", idx).Find(&realms).Error; err != nil {
|
||||
return realms, err
|
||||
}
|
||||
|
||||
return realms, nil
|
||||
}
|
||||
|
||||
func GetRealmWithAlias(alias string) (models.Realm, error) {
|
||||
var realm models.Realm
|
||||
if err := database.C.Where(&models.Realm{
|
||||
Alias: alias,
|
||||
}).First(&realm).Error; err != nil {
|
||||
return realm, err
|
||||
}
|
||||
return realm, nil
|
||||
}
|
||||
|
||||
func NewRealm(realm models.Realm, user models.Account) (models.Realm, error) {
|
||||
realm.Members = []models.RealmMember{
|
||||
{AccountID: user.ID, PowerLevel: 100},
|
||||
}
|
||||
|
||||
err := database.C.Save(&realm).Error
|
||||
return realm, err
|
||||
}
|
||||
|
||||
func ListRealmMember(realmId uint) ([]models.RealmMember, error) {
|
||||
var members []models.RealmMember
|
||||
|
||||
if err := database.C.
|
||||
Where(&models.RealmMember{RealmID: realmId}).
|
||||
Preload("Account").
|
||||
Find(&members).Error; err != nil {
|
||||
return members, err
|
||||
}
|
||||
|
||||
return members, nil
|
||||
}
|
||||
|
||||
func GetRealmMember(userId uint, realmId uint) (models.RealmMember, error) {
|
||||
var member models.RealmMember
|
||||
if err := database.C.Where(&models.RealmMember{
|
||||
AccountID: userId,
|
||||
RealmID: realmId,
|
||||
}).Find(&member).Error; err != nil {
|
||||
return member, err
|
||||
}
|
||||
return member, nil
|
||||
}
|
||||
|
||||
func AddRealmMember(user models.Account, affected models.Account, target models.Realm) error {
|
||||
if !target.IsPublic && !target.IsCommunity {
|
||||
if member, err := GetRealmMember(user.ID, target.ID); err != nil {
|
||||
return fmt.Errorf("only realm member can add people: %v", err)
|
||||
} else if member.PowerLevel < 50 {
|
||||
return fmt.Errorf("only realm moderator can add people")
|
||||
}
|
||||
friendship, err := GetFriendWithTwoSides(affected.ID, user.ID)
|
||||
if err != nil || friendship.Status != models.FriendshipActive {
|
||||
return fmt.Errorf("you only can add your friends to your realm")
|
||||
}
|
||||
}
|
||||
|
||||
member := models.RealmMember{
|
||||
RealmID: target.ID,
|
||||
AccountID: affected.ID,
|
||||
}
|
||||
err := database.C.Save(&member).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func RemoveRealmMember(user models.Account, affected models.Account, target models.Realm) error {
|
||||
if user.ID != affected.ID {
|
||||
if member, err := GetRealmMember(user.ID, target.ID); err != nil {
|
||||
return fmt.Errorf("only realm member can remove other member: %v", err)
|
||||
} else if member.PowerLevel < 50 {
|
||||
return fmt.Errorf("only realm moderator can invite people")
|
||||
}
|
||||
}
|
||||
|
||||
var member models.RealmMember
|
||||
if err := database.C.Where(&models.RealmMember{
|
||||
RealmID: target.ID,
|
||||
AccountID: affected.ID,
|
||||
}).First(&member).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return database.C.Delete(&member).Error
|
||||
}
|
||||
|
||||
func EditRealm(realm models.Realm) (models.Realm, error) {
|
||||
err := database.C.Save(&realm).Error
|
||||
return realm, err
|
||||
}
|
||||
|
||||
func DeleteRealm(realm models.Realm) error {
|
||||
return database.C.Delete(&realm).Error
|
||||
}
|
24
pkg/internal/services/ticker_maintainer.go
Normal file
24
pkg/internal/services/ticker_maintainer.go
Normal file
@ -0,0 +1,24 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/database"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/viper"
|
||||
"time"
|
||||
)
|
||||
|
||||
func DoAutoSignoff() {
|
||||
duration := time.Duration(viper.GetInt64("security.auto_signoff_duration")) * time.Second
|
||||
divider := time.Now().Add(-duration)
|
||||
|
||||
log.Debug().Time("before", divider).Msg("Now signing off tickets...")
|
||||
|
||||
if tx := database.C.
|
||||
Where("last_grant_at < ?", divider).
|
||||
Delete(&models.AuthTicket{}); tx.Error != nil {
|
||||
log.Error().Err(tx.Error).Msg("An error occurred when running auto sign off...")
|
||||
} else {
|
||||
log.Debug().Int64("affected", tx.RowsAffected).Msg("Auto sign off accomplished.")
|
||||
}
|
||||
}
|
158
pkg/internal/services/ticket.go
Normal file
158
pkg/internal/services/ticket.go
Normal file
@ -0,0 +1,158 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/database"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
func DetectRisk(user models.Account, ip, ua string) bool {
|
||||
var availableFactor int64
|
||||
if err := database.C.
|
||||
Where(models.AuthFactor{AccountID: user.ID}).
|
||||
Where("type != ?", models.PasswordAuthFactor).
|
||||
Model(models.AuthFactor{}).
|
||||
Where(&availableFactor); err != nil || availableFactor <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
var secureFactor int64
|
||||
if err := database.C.
|
||||
Where(models.AuthTicket{AccountID: user.ID, IpAddress: ip}).
|
||||
Where("available_at IS NOT NULL").
|
||||
Model(models.AuthTicket{}).
|
||||
Count(&secureFactor).Error; err == nil {
|
||||
if secureFactor >= 1 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func NewTicket(user models.Account, ip, ua string) (models.AuthTicket, error) {
|
||||
var ticket models.AuthTicket
|
||||
if err := database.C.
|
||||
Where("account_id = ? AND expired_at < ? AND available_at IS NULL", time.Now(), user.ID).
|
||||
First(&ticket).Error; err == nil {
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
requireMFA := DetectRisk(user, ip, ua)
|
||||
if count := CountUserFactor(user.ID); count <= 1 {
|
||||
requireMFA = false
|
||||
}
|
||||
|
||||
ticket = models.AuthTicket{
|
||||
Claims: []string{"*"},
|
||||
Audiences: []string{"passport"},
|
||||
IpAddress: ip,
|
||||
UserAgent: ua,
|
||||
RequireMFA: requireMFA,
|
||||
RequireAuthenticate: true,
|
||||
ExpiredAt: nil,
|
||||
AvailableAt: nil,
|
||||
AccountID: user.ID,
|
||||
}
|
||||
|
||||
err := database.C.Save(&ticket).Error
|
||||
|
||||
return ticket, err
|
||||
}
|
||||
|
||||
func NewOauthTicket(
|
||||
user models.Account,
|
||||
client models.ThirdClient,
|
||||
claims, audiences []string,
|
||||
ip, ua string,
|
||||
) (models.AuthTicket, error) {
|
||||
ticket := models.AuthTicket{
|
||||
Claims: claims,
|
||||
Audiences: audiences,
|
||||
IpAddress: ip,
|
||||
UserAgent: ua,
|
||||
RequireMFA: DetectRisk(user, ip, ua),
|
||||
GrantToken: lo.ToPtr(uuid.NewString()),
|
||||
AccessToken: lo.ToPtr(uuid.NewString()),
|
||||
RefreshToken: lo.ToPtr(uuid.NewString()),
|
||||
AvailableAt: lo.ToPtr(time.Now()),
|
||||
ExpiredAt: lo.ToPtr(time.Now()),
|
||||
ClientID: &client.ID,
|
||||
AccountID: user.ID,
|
||||
}
|
||||
|
||||
if err := database.C.Save(&ticket).Error; err != nil {
|
||||
return ticket, err
|
||||
}
|
||||
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
func ActiveTicketWithPassword(ticket models.AuthTicket, password string) (models.AuthTicket, error) {
|
||||
if ticket.AvailableAt != nil {
|
||||
return ticket, nil
|
||||
} else if !ticket.RequireAuthenticate {
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
if factor, err := GetPasswordTypeFactor(ticket.AccountID); err != nil {
|
||||
return ticket, fmt.Errorf("unable to active ticket: %v", err)
|
||||
} else if err = CheckFactor(factor, password); err != nil {
|
||||
return ticket, err
|
||||
}
|
||||
|
||||
ticket.RequireAuthenticate = false
|
||||
|
||||
if !ticket.RequireAuthenticate && !ticket.RequireMFA {
|
||||
ticket.AvailableAt = lo.ToPtr(time.Now())
|
||||
ticket.GrantToken = lo.ToPtr(uuid.NewString())
|
||||
ticket.AccessToken = lo.ToPtr(uuid.NewString())
|
||||
ticket.RefreshToken = lo.ToPtr(uuid.NewString())
|
||||
}
|
||||
|
||||
if err := database.C.Save(&ticket).Error; err != nil {
|
||||
return ticket, err
|
||||
}
|
||||
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
func ActiveTicketWithMFA(ticket models.AuthTicket, factor models.AuthFactor, code string) (models.AuthTicket, error) {
|
||||
if ticket.AvailableAt != nil {
|
||||
return ticket, nil
|
||||
} else if !ticket.RequireMFA {
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
if err := CheckFactor(factor, code); err != nil {
|
||||
return ticket, fmt.Errorf("invalid code: %v", err)
|
||||
}
|
||||
|
||||
ticket.RequireMFA = false
|
||||
|
||||
if !ticket.RequireAuthenticate && !ticket.RequireMFA {
|
||||
ticket.AvailableAt = lo.ToPtr(time.Now())
|
||||
ticket.GrantToken = lo.ToPtr(uuid.NewString())
|
||||
ticket.AccessToken = lo.ToPtr(uuid.NewString())
|
||||
ticket.RefreshToken = lo.ToPtr(uuid.NewString())
|
||||
}
|
||||
|
||||
if err := database.C.Save(&ticket).Error; err != nil {
|
||||
return ticket, err
|
||||
}
|
||||
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
func RegenSession(ticket models.AuthTicket) (models.AuthTicket, error) {
|
||||
ticket.GrantToken = lo.ToPtr(uuid.NewString())
|
||||
ticket.AccessToken = lo.ToPtr(uuid.NewString())
|
||||
ticket.RefreshToken = lo.ToPtr(uuid.NewString())
|
||||
err := database.C.Save(&ticket).Error
|
||||
return ticket, err
|
||||
}
|
29
pkg/internal/services/ticket_queries.go
Normal file
29
pkg/internal/services/ticket_queries.go
Normal file
@ -0,0 +1,29 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/database"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
)
|
||||
|
||||
func GetTicket(id uint) (models.AuthTicket, error) {
|
||||
var ticket models.AuthTicket
|
||||
if err := database.C.
|
||||
Where(&models.AuthTicket{BaseModel: models.BaseModel{ID: id}}).
|
||||
First(&ticket).Error; err != nil {
|
||||
return ticket, err
|
||||
}
|
||||
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
func GetTicketWithToken(tokenId string) (models.AuthTicket, error) {
|
||||
var ticket models.AuthTicket
|
||||
if err := database.C.
|
||||
Where(models.AuthTicket{AccessToken: &tokenId}).
|
||||
Or(models.AuthTicket{RefreshToken: &tokenId}).
|
||||
First(&ticket).Error; err != nil {
|
||||
return ticket, err
|
||||
}
|
||||
|
||||
return ticket, nil
|
||||
}
|
98
pkg/internal/services/ticket_token.go
Normal file
98
pkg/internal/services/ticket_token.go
Normal file
@ -0,0 +1,98 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/database"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
"github.com/samber/lo"
|
||||
"github.com/spf13/viper"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func GetToken(ticket models.AuthTicket) (string, string, error) {
|
||||
var refresh, access string
|
||||
if err := ticket.IsAvailable(); err != nil {
|
||||
return refresh, access, err
|
||||
}
|
||||
if ticket.AccessToken == nil || ticket.RefreshToken == nil {
|
||||
return refresh, access, fmt.Errorf("unable to encode token, access or refresh token id missing")
|
||||
}
|
||||
|
||||
accessDuration := time.Duration(viper.GetInt64("security.access_token_duration")) * time.Second
|
||||
refreshDuration := time.Duration(viper.GetInt64("security.refresh_token_duration")) * time.Second
|
||||
|
||||
var err error
|
||||
sub := strconv.Itoa(int(ticket.AccountID))
|
||||
sed := strconv.Itoa(int(ticket.ID))
|
||||
access, err = EncodeJwt(*ticket.AccessToken, JwtAccessType, sub, sed, ticket.Audiences, time.Now().Add(accessDuration))
|
||||
if err != nil {
|
||||
return refresh, access, err
|
||||
}
|
||||
refresh, err = EncodeJwt(*ticket.RefreshToken, JwtRefreshType, sub, sed, ticket.Audiences, time.Now().Add(refreshDuration))
|
||||
if err != nil {
|
||||
return refresh, access, err
|
||||
}
|
||||
|
||||
ticket.LastGrantAt = lo.ToPtr(time.Now())
|
||||
database.C.Save(&ticket)
|
||||
|
||||
return access, refresh, nil
|
||||
}
|
||||
|
||||
func ExchangeToken(token string) (string, string, error) {
|
||||
var ticket models.AuthTicket
|
||||
if err := database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil {
|
||||
return "", "", err
|
||||
} else if ticket.LastGrantAt != nil {
|
||||
return "", "", fmt.Errorf("ticket was granted the first token, use refresh token instead")
|
||||
} else if len(ticket.Audiences) > 1 {
|
||||
return "", "", fmt.Errorf("should use authorization code grant type")
|
||||
}
|
||||
|
||||
return GetToken(ticket)
|
||||
}
|
||||
|
||||
func ExchangeOauthToken(clientId, clientSecret, redirectUri, token string) (string, string, error) {
|
||||
var client models.ThirdClient
|
||||
if err := database.C.Where(models.ThirdClient{Alias: clientId}).First(&client).Error; err != nil {
|
||||
return "", "", err
|
||||
} else if client.Secret != clientSecret {
|
||||
return "", "", fmt.Errorf("invalid client secret")
|
||||
} else if !client.IsDraft && !lo.Contains(client.Callbacks, redirectUri) {
|
||||
return "", "", fmt.Errorf("invalid redirect uri")
|
||||
}
|
||||
|
||||
var ticket models.AuthTicket
|
||||
if err := database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil {
|
||||
return "", "", err
|
||||
} else if ticket.LastGrantAt != nil {
|
||||
return "", "", fmt.Errorf("ticket was granted the first token, use refresh token instead")
|
||||
}
|
||||
|
||||
return GetToken(ticket)
|
||||
}
|
||||
|
||||
func RefreshToken(token string) (string, string, error) {
|
||||
parseInt := func(str string) int {
|
||||
val, _ := strconv.Atoi(str)
|
||||
return val
|
||||
}
|
||||
|
||||
var ticket models.AuthTicket
|
||||
if claims, err := DecodeJwt(token); err != nil {
|
||||
return "404", "403", err
|
||||
} else if claims.Type != JwtRefreshType {
|
||||
return "404", "403", fmt.Errorf("invalid token type, expected refresh token")
|
||||
} else if err := database.C.Where(models.AuthTicket{
|
||||
BaseModel: models.BaseModel{ID: uint(parseInt(claims.SessionID))},
|
||||
}).First(&ticket).Error; err != nil {
|
||||
return "404", "403", err
|
||||
}
|
||||
|
||||
if ticket, err := RegenSession(ticket); err != nil {
|
||||
return "404", "403", err
|
||||
} else {
|
||||
return GetToken(ticket)
|
||||
}
|
||||
}
|
92
pkg/internal/services/tokens.go
Normal file
92
pkg/internal/services/tokens.go
Normal file
@ -0,0 +1,92 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/database"
|
||||
"git.solsynth.dev/hydrogen/passport/pkg/models"
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
const ConfirmRegistrationTemplate = `Dear %s,
|
||||
|
||||
Thank you for choosing to register with %s. We are excited to welcome you to our community and appreciate your trust in us.
|
||||
|
||||
Your registration details have been successfully received, and you are now a valued member of %s. Here are the confirm link of your registration:
|
||||
|
||||
%s
|
||||
|
||||
As a confirmed registered member, you will have access to all our services.
|
||||
We encourage you to explore our services and take full advantage of the resources available to you.
|
||||
|
||||
Once again, thank you for choosing us. We look forward to serving you and hope you have a positive experience with us.
|
||||
|
||||
Best regards,
|
||||
%s`
|
||||
|
||||
func ValidateMagicToken(code string, mode models.MagicTokenType) (models.MagicToken, error) {
|
||||
var tk models.MagicToken
|
||||
if err := database.C.Where(models.MagicToken{Code: code, Type: mode}).First(&tk).Error; err != nil {
|
||||
return tk, err
|
||||
} else if tk.ExpiredAt != nil && time.Now().Unix() >= tk.ExpiredAt.Unix() {
|
||||
return tk, fmt.Errorf("token has been expired")
|
||||
}
|
||||
|
||||
return tk, nil
|
||||
}
|
||||
|
||||
func NewMagicToken(mode models.MagicTokenType, assignTo *models.Account, expiredAt *time.Time) (models.MagicToken, error) {
|
||||
var uid uint
|
||||
if assignTo != nil {
|
||||
uid = assignTo.ID
|
||||
}
|
||||
|
||||
token := models.MagicToken{
|
||||
Code: strings.Replace(uuid.NewString(), "-", "", -1),
|
||||
Type: mode,
|
||||
AssignTo: &uid,
|
||||
ExpiredAt: expiredAt,
|
||||
}
|
||||
|
||||
if err := database.C.Save(&token).Error; err != nil {
|
||||
return token, err
|
||||
} else {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
func NotifyMagicToken(token models.MagicToken) error {
|
||||
if token.AssignTo == nil {
|
||||
return fmt.Errorf("could notify a non-assign magic token")
|
||||
}
|
||||
|
||||
var user models.Account
|
||||
if err := database.C.Where(&models.Account{
|
||||
BaseModel: models.BaseModel{ID: *token.AssignTo},
|
||||
}).Preload("Contacts").First(&user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var subject string
|
||||
var content string
|
||||
switch token.Type {
|
||||
case models.ConfirmMagicToken:
|
||||
link := fmt.Sprintf("https://%s/me/confirm?tk=%s", viper.GetString("domain"), token.Code)
|
||||
subject = fmt.Sprintf("[%s] Confirm your registration", viper.GetString("name"))
|
||||
content = fmt.Sprintf(
|
||||
ConfirmRegistrationTemplate,
|
||||
user.Name,
|
||||
viper.GetString("name"),
|
||||
viper.GetString("maintainer"),
|
||||
link,
|
||||
viper.GetString("maintainer"),
|
||||
)
|
||||
default:
|
||||
return fmt.Errorf("unsupported magic token type to notify")
|
||||
}
|
||||
|
||||
return SendMail(user.GetPrimaryEmail().Content, subject, content)
|
||||
}
|
Reference in New Issue
Block a user