🎨 Update project structure

This commit is contained in:
2024-06-16 23:17:32 +08:00
parent 0695338fa1
commit 45048ea814
103 changed files with 138 additions and 40 deletions

View File

@ -0,0 +1,123 @@
package services
import (
"fmt"
"github.com/spf13/viper"
"gorm.io/datatypes"
"time"
"git.solsynth.dev/hydrogen/passport/pkg/database"
"git.solsynth.dev/hydrogen/passport/pkg/models"
"github.com/google/uuid"
"github.com/samber/lo"
"gorm.io/gorm"
)
func GetAccount(id uint) (models.Account, error) {
var account models.Account
if err := database.C.Where(models.Account{
BaseModel: models.BaseModel{ID: id},
}).First(&account).Error; err != nil {
return account, err
}
return account, nil
}
func LookupAccount(probe string) (models.Account, error) {
var account models.Account
if err := database.C.Where(models.Account{Name: probe}).First(&account).Error; err == nil {
return account, nil
}
var contact models.AccountContact
if err := database.C.Where(models.AccountContact{Content: probe}).First(&contact).Error; err == nil {
if err := database.C.
Where(models.Account{
BaseModel: models.BaseModel{ID: contact.AccountID},
}).First(&account).Error; err == nil {
return account, err
}
}
return account, fmt.Errorf("account was not found")
}
func CreateAccount(name, nick, email, password string) (models.Account, error) {
user := models.Account{
Name: name,
Nick: nick,
Profile: models.AccountProfile{
Experience: 100,
},
Factors: []models.AuthFactor{
{
Type: models.PasswordAuthFactor,
Secret: HashPassword(password),
},
{
Type: models.EmailPasswordFactor,
Secret: uuid.NewString()[:8],
},
},
Contacts: []models.AccountContact{
{
Type: models.EmailAccountContact,
Content: email,
IsPrimary: true,
VerifiedAt: nil,
},
},
PermNodes: datatypes.JSONMap(viper.GetStringMap("permissions.default")),
ConfirmedAt: nil,
}
if err := database.C.Create(&user).Error; err != nil {
return user, err
}
if tk, err := NewMagicToken(models.ConfirmMagicToken, &user, nil); err != nil {
return user, err
} else if err := NotifyMagicToken(tk); err != nil {
return user, err
}
return user, nil
}
func ConfirmAccount(code string) error {
token, err := ValidateMagicToken(code, models.ConfirmMagicToken)
if err != nil {
return err
}
var user models.Account
if err := database.C.Where(&models.Account{
BaseModel: models.BaseModel{ID: *token.AssignTo},
}).First(&user).Error; err != nil {
return err
}
return database.C.Transaction(func(tx *gorm.DB) error {
user.ConfirmedAt = lo.ToPtr(time.Now())
for k, v := range viper.GetStringMap("permissions.verified") {
if val, ok := user.PermNodes[k]; !ok {
user.PermNodes[k] = v
} else if !ComparePermNode(val, v) {
user.PermNodes[k] = v
}
}
if err := database.C.Delete(&token).Error; err != nil {
return err
}
if err := database.C.Save(&user).Error; err != nil {
return err
}
InvalidAuthCacheWithUser(user.ID)
return nil
})
}

View File

@ -0,0 +1,125 @@
package services
import (
"fmt"
"sync"
"time"
jsoniter "github.com/json-iterator/go"
"git.solsynth.dev/hydrogen/passport/pkg/models"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
var (
authContextMutex sync.Mutex
authContextCache = make(map[string]models.AuthContext)
)
func Authenticate(access, refresh string, depth int) (ctx models.AuthContext, perms map[string]any, newAccess, newRefresh string, err error) {
var claims PayloadClaims
claims, err = DecodeJwt(access)
if err != nil {
if len(refresh) > 0 && depth < 1 {
// Auto refresh and retry
newAccess, newRefresh, err = RefreshToken(refresh)
if err == nil {
return Authenticate(newAccess, newRefresh, depth+1)
}
}
err = fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("invalid auth key: %v", err))
return
}
newAccess = access
newRefresh = refresh
if ctx, err = GetAuthContext(claims.ID); err == nil {
var heldPerms map[string]any
rawHeldPerms, _ := jsoniter.Marshal(ctx.Account.PermNodes)
_ = jsoniter.Unmarshal(rawHeldPerms, &heldPerms)
perms = FilterPermNodes(heldPerms, ctx.Ticket.Claims)
return
}
err = fiber.NewError(fiber.StatusUnauthorized, err.Error())
return
}
func GetAuthContext(jti string) (models.AuthContext, error) {
var err error
var ctx models.AuthContext
if val, ok := authContextCache[jti]; ok {
ctx = val
ctx.LastUsedAt = time.Now()
authContextMutex.Lock()
authContextCache[jti] = ctx
authContextMutex.Unlock()
log.Debug().Str("jti", jti).Msg("Used an auth context cache")
} else {
ctx, err = CacheAuthContext(jti)
log.Debug().Str("jti", jti).Msg("Created a new auth context cache")
}
return ctx, err
}
func CacheAuthContext(jti string) (models.AuthContext, error) {
var ctx models.AuthContext
// Query data from primary database
ticket, err := GetTicketWithToken(jti)
if err != nil {
return ctx, fmt.Errorf("invalid auth ticket: %v", err)
} else if err := ticket.IsAvailable(); err != nil {
return ctx, fmt.Errorf("unavailable auth ticket: %v", err)
}
user, err := GetAccount(ticket.AccountID)
if err != nil {
return ctx, fmt.Errorf("invalid account: %v", err)
}
ctx = models.AuthContext{
Ticket: ticket,
Account: user,
LastUsedAt: time.Now(),
}
// Put the data into memory for cache
authContextMutex.Lock()
authContextCache[jti] = ctx
authContextMutex.Unlock()
return ctx, nil
}
func RecycleAuthContext() {
if len(authContextCache) == 0 {
return
}
affected := 0
for key, val := range authContextCache {
if val.LastUsedAt.Add(60*time.Second).Unix() < time.Now().Unix() {
affected++
authContextMutex.Lock()
delete(authContextCache, key)
authContextMutex.Unlock()
}
}
log.Debug().Int("affected", affected).Msg("Recycled auth context...")
}
func InvalidAuthCacheWithUser(userId uint) {
for key, val := range authContextCache {
if val.Account.ID == userId {
authContextMutex.Lock()
delete(authContextCache, key)
authContextMutex.Unlock()
}
}
}

View File

@ -0,0 +1,15 @@
package services
import (
"git.solsynth.dev/hydrogen/passport/pkg/database"
"git.solsynth.dev/hydrogen/passport/pkg/models"
)
func GrantBadge(user models.Account, badge models.Badge) error {
badge.AccountID = user.ID
return database.C.Save(badge).Error
}
func RevokeBadge(badge models.Badge) error {
return database.C.Delete(&badge).Error
}

View File

@ -0,0 +1,23 @@
package services
import (
"git.solsynth.dev/hydrogen/passport/pkg/database"
"github.com/rs/zerolog/log"
"time"
)
func DoAutoDatabaseCleanup() {
deadline := time.Now().Add(60 * time.Minute)
log.Debug().Time("deadline", deadline).Msg("Now cleaning up entire database...")
var count int64
for _, model := range database.AutoMaintainRange {
tx := database.C.Unscoped().Delete(model, "deleted_at >= ?", deadline)
if tx.Error != nil {
log.Error().Err(tx.Error).Msg("An error occurred when running auth context cleanup...")
}
count += tx.RowsAffected
}
log.Debug().Int64("affected", count).Msg("Clean up entire database accomplished.")
}

View File

@ -0,0 +1,32 @@
package services
import (
"fmt"
"git.solsynth.dev/hydrogen/passport/pkg/database"
"git.solsynth.dev/hydrogen/passport/pkg/models"
)
func GetThirdClient(id string) (models.ThirdClient, error) {
var client models.ThirdClient
if err := database.C.Where(&models.ThirdClient{
Alias: id,
}).First(&client).Error; err != nil {
return client, err
}
return client, nil
}
func GetThirdClientWithSecret(id, secret string) (models.ThirdClient, error) {
client, err := GetThirdClient(id)
if err != nil {
return client, err
}
if client.Secret != secret {
return client, fmt.Errorf("invalid client secret")
}
return client, nil
}

View File

@ -0,0 +1,31 @@
package services
import (
"sync"
"git.solsynth.dev/hydrogen/passport/pkg/models"
"github.com/gofiber/contrib/websocket"
)
var (
wsMutex sync.Mutex
wsConn = make(map[uint]map[*websocket.Conn]bool)
)
func ClientRegister(user models.Account, conn *websocket.Conn) {
wsMutex.Lock()
if wsConn[user.ID] == nil {
wsConn[user.ID] = make(map[*websocket.Conn]bool)
}
wsConn[user.ID][conn] = true
wsMutex.Unlock()
}
func ClientUnregister(user models.Account, conn *websocket.Conn) {
wsMutex.Lock()
if wsConn[user.ID] == nil {
wsConn[user.ID] = make(map[*websocket.Conn]bool)
}
delete(wsConn[user.ID], conn)
wsMutex.Unlock()
}

View File

@ -0,0 +1,82 @@
package services
import (
"git.solsynth.dev/hydrogen/passport/pkg/models"
"github.com/gofiber/contrib/websocket"
"github.com/gofiber/fiber/v2"
"time"
)
type kexRequest struct {
OwnerID uint
Conn *websocket.Conn
Deadline time.Time
}
var kexRequests = make(map[string]map[string]kexRequest)
func KexRequest(conn *websocket.Conn, requestId, keypairId, algorithm string, ownerId uint, deadline int64) {
if kexRequests[keypairId] == nil {
kexRequests[keypairId] = make(map[string]kexRequest)
}
ddl := time.Now().Add(time.Second * time.Duration(deadline))
request := kexRequest{
OwnerID: ownerId,
Conn: conn,
Deadline: ddl,
}
flag := false
for c := range wsConn[ownerId] {
if c == conn {
continue
}
if c.WriteMessage(1, models.UnifiedCommand{
Action: "kex.request",
Payload: fiber.Map{
"request_id": requestId,
"keypair_id": keypairId,
"algorithm": algorithm,
"owner_id": ownerId,
"deadline": deadline,
},
}.Marshal()) == nil {
flag = true
}
}
if flag {
kexRequests[keypairId][requestId] = request
}
}
func KexProvide(userId uint, requestId, keypairId string, pkt []byte) {
if kexRequests[keypairId] == nil {
return
}
val, ok := kexRequests[keypairId][requestId]
if !ok {
return
} else if val.OwnerID != userId {
return
} else {
_ = val.Conn.WriteMessage(1, pkt)
}
}
func KexCleanup() {
if len(kexRequests) <= 0 {
return
}
for kp, data := range kexRequests {
for idx, req := range data {
if req.Deadline.Unix() <= time.Now().Unix() {
delete(kexRequests[kp], idx)
}
}
}
}

View File

@ -0,0 +1,12 @@
package services
import "golang.org/x/crypto/bcrypt"
func HashPassword(raw string) string {
data, _ := bcrypt.GenerateFromPassword([]byte(raw), 12)
return string(data)
}
func VerifyPassword(text string, password string) bool {
return bcrypt.CompareHashAndPassword([]byte(password), []byte(text)) == nil
}

View File

@ -0,0 +1,20 @@
package services
import (
"git.solsynth.dev/hydrogen/passport/pkg/database"
"git.solsynth.dev/hydrogen/passport/pkg/models"
)
func AddEvent(user models.Account, event, target, ip, ua string) models.ActionEvent {
evt := models.ActionEvent{
Type: event,
Target: target,
IpAddress: ip,
UserAgent: ua,
AccountID: user.ID,
}
database.C.Save(&evt)
return evt
}

View File

@ -0,0 +1,25 @@
package services
import (
"github.com/sideshow/apns2"
"github.com/sideshow/apns2/token"
"github.com/spf13/viper"
)
// ExtAPNS is Apple Notification Services client
var ExtAPNS *apns2.Client
func SetupAPNS() error {
authKey, err := token.AuthKeyFromFile(viper.GetString("apns_credentials"))
if err != nil {
return err
}
ExtAPNS = apns2.NewTokenClient(&token.Token{
AuthKey: authKey,
KeyID: viper.GetString("apns_credentials_key"),
TeamID: viper.GetString("apns_credentials_team"),
}).Production()
return nil
}

View File

@ -0,0 +1,23 @@
package services
import (
"context"
firebase "firebase.google.com/go"
"github.com/spf13/viper"
"google.golang.org/api/option"
)
// ExtFire is the firebase app client
var ExtFire *firebase.App
func SetupFirebase() error {
opt := option.WithCredentialsFile(viper.GetString("firebase_credentials"))
app, err := firebase.NewApp(context.Background(), nil, opt)
if err != nil {
return err
} else {
ExtFire = app
}
return nil
}

View File

@ -0,0 +1,109 @@
package services
import (
"fmt"
"github.com/samber/lo"
"git.solsynth.dev/hydrogen/passport/pkg/database"
"git.solsynth.dev/hydrogen/passport/pkg/models"
"github.com/google/uuid"
"github.com/spf13/viper"
)
const EmailPasswordTemplate = `Dear %s,
We hope this message finds you well.
As part of our ongoing commitment to ensuring the security of your account, we require you to complete the login process by entering the verification code below:
Your Login Verification Code: %s
Please use the provided code within the next 2 hours to complete your login.
If you did not request this code, please update your information, maybe your username or email has been leak.
Thank you for your cooperation in helping us maintain the security of your account.
Best regards,
%s`
func GetPasswordTypeFactor(userId uint) (models.AuthFactor, error) {
var factor models.AuthFactor
err := database.C.Where(models.AuthFactor{
Type: models.PasswordAuthFactor,
AccountID: userId,
}).First(&factor).Error
return factor, err
}
func GetFactor(id uint) (models.AuthFactor, error) {
var factor models.AuthFactor
err := database.C.Where(models.AuthFactor{
BaseModel: models.BaseModel{ID: id},
}).First(&factor).Error
return factor, err
}
func ListUserFactor(userId uint) ([]models.AuthFactor, error) {
var factors []models.AuthFactor
err := database.C.Where(models.AuthFactor{
AccountID: userId,
}).Find(&factors).Error
return factors, err
}
func CountUserFactor(userId uint) int64 {
var count int64
database.C.Where(models.AuthFactor{
AccountID: userId,
}).Model(&models.AuthFactor{}).Count(&count)
return count
}
func GetFactorCode(factor models.AuthFactor) (bool, error) {
switch factor.Type {
case models.EmailPasswordFactor:
var user models.Account
if err := database.C.Where(&models.Account{
BaseModel: models.BaseModel{ID: factor.AccountID},
}).Preload("Contacts").First(&user).Error; err != nil {
return true, err
}
factor.Secret = uuid.NewString()[:6]
if err := database.C.Save(&factor).Error; err != nil {
return true, err
}
subject := fmt.Sprintf("[%s] Login verification code", viper.GetString("name"))
content := fmt.Sprintf(EmailPasswordTemplate, user.Name, factor.Secret, viper.GetString("maintainer"))
if err := SendMail(user.GetPrimaryEmail().Content, subject, content); err != nil {
return true, err
}
return true, nil
default:
return false, nil
}
}
func CheckFactor(factor models.AuthFactor, code string) error {
switch factor.Type {
case models.PasswordAuthFactor:
return lo.Ternary(
VerifyPassword(code, factor.Secret),
nil,
fmt.Errorf("invalid password"),
)
case models.EmailPasswordFactor:
return lo.Ternary(
code == factor.Secret,
nil,
fmt.Errorf("invalid verification code"),
)
}
return nil
}

View File

@ -0,0 +1,125 @@
package services
import (
"errors"
"fmt"
"git.solsynth.dev/hydrogen/passport/pkg/database"
"git.solsynth.dev/hydrogen/passport/pkg/models"
"gorm.io/gorm"
)
func ListAllFriend(anyside models.Account) ([]models.AccountFriendship, error) {
var relationships []models.AccountFriendship
if err := database.C.
Where("account_id = ? OR related_id = ?", anyside.ID, anyside.ID).
Preload("Account").
Preload("Related").
Find(&relationships).Error; err != nil {
return relationships, err
}
return relationships, nil
}
func ListFriend(anyside models.Account, status models.FriendshipStatus) ([]models.AccountFriendship, error) {
var relationships []models.AccountFriendship
if err := database.C.
Where("(account_id = ? OR related_id = ?) AND status = ?", anyside.ID, anyside.ID, status).
Preload("Account").
Preload("Related").
Find(&relationships).Error; err != nil {
return relationships, err
}
return relationships, nil
}
func GetFriend(anysideId uint) (models.AccountFriendship, error) {
var relationship models.AccountFriendship
if err := database.C.
Where(&models.AccountFriendship{AccountID: anysideId}).
Or(&models.AccountFriendship{RelatedID: anysideId}).
Preload("Account").
Preload("Related").
First(&relationship).Error; err != nil {
return relationship, err
}
return relationship, nil
}
func GetFriendWithTwoSides(userId, relatedId uint, noPreload ...bool) (models.AccountFriendship, error) {
var tx *gorm.DB
if len(noPreload) > 0 && noPreload[0] {
tx = database.C
} else {
tx = database.C.Preload("Account").Preload("Related")
}
var relationship models.AccountFriendship
if err := tx.
Where(&models.AccountFriendship{AccountID: userId, RelatedID: relatedId}).
Or(&models.AccountFriendship{RelatedID: userId, AccountID: relatedId}).
First(&relationship).Error; err != nil {
return relationship, err
}
return relationship, nil
}
func NewFriend(user models.Account, related models.Account, status models.FriendshipStatus) (models.AccountFriendship, error) {
relationship := models.AccountFriendship{
AccountID: user.ID,
RelatedID: related.ID,
Status: status,
}
if user.ID == related.ID {
return relationship, fmt.Errorf("you cannot make friendship with yourself")
} else if _, err := GetFriendWithTwoSides(user.ID, related.ID, true); err == nil || !errors.Is(err, gorm.ErrRecordNotFound) {
return relationship, fmt.Errorf("you already have a friendship with him or her")
}
if err := database.C.Save(&relationship).Error; err != nil {
return relationship, err
} else {
_ = NewNotification(models.Notification{
Subject: fmt.Sprintf("New friend request from %s", user.Name),
Content: fmt.Sprintf("You got a new friend request from %s. Go to your settings and decide how to deal it.", user.Nick),
RecipientID: related.ID,
})
}
return relationship, nil
}
func EditFriendWithCheck(relationship models.AccountFriendship, user models.Account, originalStatus models.FriendshipStatus) (models.AccountFriendship, error) {
if relationship.Status != originalStatus {
if originalStatus == models.FriendshipBlocked && relationship.BlockedBy != nil && user.ID != *relationship.BlockedBy {
return relationship, fmt.Errorf("the friendship has been blocked by the otherside, you cannot modify it status")
}
if relationship.Status == models.FriendshipPending && relationship.RelatedID != user.ID {
return relationship, fmt.Errorf("only related person can accept friendship")
}
}
if originalStatus != models.FriendshipBlocked && relationship.Status == models.FriendshipBlocked {
relationship.BlockedBy = &user.ID
}
return EditFriend(relationship)
}
func EditFriend(relationship models.AccountFriendship) (models.AccountFriendship, error) {
if err := database.C.Save(&relationship).Error; err != nil {
return relationship, err
}
return relationship, nil
}
func DeleteFriend(relationship models.AccountFriendship) error {
if err := database.C.Delete(&relationship).Error; err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,81 @@
package services
import (
"fmt"
"github.com/gofiber/fiber/v2"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/spf13/viper"
)
var CookieAccessKey = "passport_auth_key"
var CookieRefreshKey = "passport_refresh_key"
type PayloadClaims struct {
jwt.RegisteredClaims
SessionID string `json:"sed"`
Type string `json:"typ"`
}
const (
JwtAccessType = "access"
JwtRefreshType = "refresh"
)
func EncodeJwt(id string, typ, sub, sed string, aud []string, exp time.Time) (string, error) {
tk := jwt.NewWithClaims(jwt.SigningMethodHS512, PayloadClaims{
jwt.RegisteredClaims{
Subject: sub,
Audience: aud,
Issuer: fmt.Sprintf("https://%s", viper.GetString("domain")),
ExpiresAt: jwt.NewNumericDate(exp),
NotBefore: jwt.NewNumericDate(time.Now()),
IssuedAt: jwt.NewNumericDate(time.Now()),
ID: id,
},
sed,
typ,
})
return tk.SignedString([]byte(viper.GetString("secret")))
}
func DecodeJwt(str string) (PayloadClaims, error) {
var claims PayloadClaims
tk, err := jwt.ParseWithClaims(str, &claims, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(viper.GetString("secret")), nil
})
if err != nil {
return claims, err
}
if data, ok := tk.Claims.(*PayloadClaims); ok {
return *data, nil
} else {
return claims, fmt.Errorf("unexpected token payload: not payload claims type")
}
}
func SetJwtCookieSet(c *fiber.Ctx, access, refresh string) {
c.Cookie(&fiber.Cookie{
Name: CookieAccessKey,
Value: access,
Domain: viper.GetString("security.cookie_domain"),
SameSite: viper.GetString("security.cookie_samesite"),
Expires: time.Now().Add(60 * time.Minute),
Path: "/",
})
c.Cookie(&fiber.Cookie{
Name: CookieRefreshKey,
Value: refresh,
Domain: viper.GetString("security.cookie_domain"),
SameSite: viper.GetString("security.cookie_samesite"),
Expires: time.Now().Add(24 * 30 * time.Hour),
Path: "/",
})
}

View File

@ -0,0 +1,51 @@
package services
import (
"crypto/tls"
"fmt"
"net/smtp"
"net/textproto"
"github.com/jordan-wright/email"
"github.com/spf13/viper"
)
func SendMail(target string, subject string, content string) error {
mail := &email.Email{
To: []string{target},
From: viper.GetString("mailer.name"),
Subject: subject,
Text: []byte(content),
Headers: textproto.MIMEHeader{},
}
return mail.SendWithTLS(
fmt.Sprintf("%s:%d", viper.GetString("mailer.smtp_host"), viper.GetInt("mailer.smtp_port")),
smtp.PlainAuth(
"",
viper.GetString("mailer.username"),
viper.GetString("mailer.password"),
viper.GetString("mailer.smtp_host"),
),
&tls.Config{ServerName: viper.GetString("mailer.smtp_host")},
)
}
func SendMailHTML(target string, subject string, content string) error {
mail := &email.Email{
To: []string{target},
From: viper.GetString("mailer.name"),
Subject: subject,
HTML: []byte(content),
Headers: textproto.MIMEHeader{},
}
return mail.SendWithTLS(
fmt.Sprintf("%s:%d", viper.GetString("mailer.smtp_host"), viper.GetInt("mailer.smtp_port")),
smtp.PlainAuth(
"",
viper.GetString("mailer.username"),
viper.GetString("mailer.password"),
viper.GetString("mailer.smtp_host"),
),
&tls.Config{ServerName: viper.GetString("mailer.smtp_host")},
)
}

View File

@ -0,0 +1,18 @@
package services
import (
"git.solsynth.dev/hydrogen/passport/pkg/models"
"github.com/nicksnyder/go-i18n/v2/i18n"
)
func GetFactorName(w models.AuthFactorType, localizer *i18n.Localizer) string {
unknown, _ := localizer.LocalizeMessage(&i18n.Message{ID: "unknown"})
mfaEmail, _ := localizer.LocalizeMessage(&i18n.Message{ID: "mfaFactorEmail"})
switch w {
case models.EmailPasswordFactor:
return mfaEmail
default:
return unknown
}
}

View File

@ -0,0 +1,138 @@
package services
import (
"context"
"firebase.google.com/go/messaging"
"git.solsynth.dev/hydrogen/passport/pkg/database"
"git.solsynth.dev/hydrogen/passport/pkg/models"
"github.com/rs/zerolog/log"
"github.com/sideshow/apns2"
payload2 "github.com/sideshow/apns2/payload"
"github.com/spf13/viper"
"reflect"
)
func AddNotifySubscriber(user models.Account, provider, id, tk, ua string) (models.NotificationSubscriber, error) {
var prev models.NotificationSubscriber
var subscriber models.NotificationSubscriber
if err := database.C.Where(&models.NotificationSubscriber{
DeviceID: id,
AccountID: user.ID,
}); err != nil {
subscriber = models.NotificationSubscriber{
UserAgent: ua,
Provider: provider,
DeviceID: id,
DeviceToken: tk,
AccountID: user.ID,
}
} else {
prev = subscriber
}
subscriber.UserAgent = ua
subscriber.Provider = provider
subscriber.DeviceToken = tk
var err error
if !reflect.DeepEqual(subscriber, prev) {
err = database.C.Save(&subscriber).Error
}
return subscriber, err
}
func NewNotification(notification models.Notification) error {
if err := database.C.Save(&notification).Error; err != nil {
return err
}
go func() {
err := PushNotification(notification)
if err != nil {
log.Error().Err(err).Msg("Unexpected error occurred during the notification.")
}
}()
return nil
}
func PushNotification(notification models.Notification) error {
for conn := range wsConn[notification.RecipientID] {
_ = conn.WriteMessage(1, models.UnifiedCommand{
Action: "notifications.new",
Payload: notification,
}.Marshal())
}
// TODO Detect the push notification is turned off (still push when IsForcePush is on)
var subscribers []models.NotificationSubscriber
if err := database.C.Where(&models.NotificationSubscriber{
AccountID: notification.RecipientID,
}).Find(&subscribers).Error; err != nil {
return err
}
for _, subscriber := range subscribers {
switch subscriber.Provider {
case models.NotifySubscriberFirebase:
if ExtFire != nil {
ctx := context.Background()
client, err := ExtFire.Messaging(ctx)
if err != nil {
log.Warn().Err(err).Msg("An error occurred when creating FCM client...")
break
}
message := &messaging.Message{
Notification: &messaging.Notification{
Title: notification.Subject,
Body: notification.Content,
},
Token: subscriber.DeviceToken,
}
if response, err := client.Send(ctx, message); err != nil {
log.Warn().Err(err).Msg("An error occurred when notify subscriber via FCM...")
} else {
log.Debug().
Str("response", response).
Int("subscriber", int(subscriber.ID)).
Msg("Notified subscriber via FCM.")
}
}
case models.NotifySubscriberAPNs:
if ExtAPNS != nil {
data, err := payload2.
NewPayload().
AlertTitle(notification.Subject).
AlertBody(notification.Content).
Sound("default").
Category(notification.Type).
MarshalJSON()
if err != nil {
log.Warn().Err(err).Msg("An error occurred when preparing to notify subscriber via APNs...")
}
payload := &apns2.Notification{
ApnsID: subscriber.DeviceID,
DeviceToken: subscriber.DeviceToken,
Topic: viper.GetString("apns_topic"),
Payload: data,
}
if resp, err := ExtAPNS.Push(payload); err != nil {
log.Warn().Err(err).Msg("An error occurred when notify subscriber via APNs...")
} else {
log.Debug().
Str("reason", resp.Reason).
Int("status", resp.StatusCode).
Int("subscriber", int(subscriber.ID)).
Msg("Notified subscriber via APNs.")
}
}
}
}
return nil
}

View File

@ -0,0 +1,63 @@
package services
import (
"fmt"
"reflect"
"regexp"
"strings"
)
func HasPermNode(perms map[string]any, requiredKey string, requiredValue any) bool {
if heldValue, ok := perms[requiredKey]; ok {
return ComparePermNode(heldValue, requiredValue)
}
return false
}
func ComparePermNode(held any, required any) bool {
heldValue := reflect.ValueOf(held)
requiredValue := reflect.ValueOf(required)
switch heldValue.Kind() {
case reflect.Int, reflect.Float64:
if heldValue.Float() >= requiredValue.Float() {
return true
}
case reflect.String:
if heldValue.String() == requiredValue.String() {
return true
}
case reflect.Slice, reflect.Array:
for i := 0; i < heldValue.Len(); i++ {
if reflect.DeepEqual(heldValue.Index(i).Interface(), required) {
return true
}
}
default:
if reflect.DeepEqual(held, required) {
return true
}
}
return false
}
func FilterPermNodes(tree map[string]any, claims []string) map[string]any {
filteredTree := make(map[string]any)
match := func(claim, permission string) bool {
regex := strings.ReplaceAll(claim, "*", ".*")
match, _ := regexp.MatchString(fmt.Sprintf("^%s$", regex), permission)
return match
}
for _, claim := range claims {
for key, value := range tree {
if match(claim, key) {
filteredTree[key] = value
}
}
}
return filteredTree
}

View File

@ -0,0 +1,141 @@
package services
import (
"fmt"
"git.solsynth.dev/hydrogen/passport/pkg/database"
"git.solsynth.dev/hydrogen/passport/pkg/models"
"github.com/samber/lo"
)
func ListCommunityRealm() ([]models.Realm, error) {
var realms []models.Realm
if err := database.C.Where(&models.Realm{
IsCommunity: true,
}).Find(&realms).Error; err != nil {
return realms, err
}
return realms, nil
}
func ListOwnedRealm(user models.Account) ([]models.Realm, error) {
var realms []models.Realm
if err := database.C.Where(&models.Realm{AccountID: user.ID}).Find(&realms).Error; err != nil {
return realms, err
}
return realms, nil
}
func ListAvailableRealm(user models.Account) ([]models.Realm, error) {
var realms []models.Realm
var members []models.RealmMember
if err := database.C.Where(&models.RealmMember{
AccountID: user.ID,
}).Find(&members).Error; err != nil {
return realms, err
}
idx := lo.Map(members, func(item models.RealmMember, index int) uint {
return item.RealmID
})
if err := database.C.Where("id IN ?", idx).Find(&realms).Error; err != nil {
return realms, err
}
return realms, nil
}
func GetRealmWithAlias(alias string) (models.Realm, error) {
var realm models.Realm
if err := database.C.Where(&models.Realm{
Alias: alias,
}).First(&realm).Error; err != nil {
return realm, err
}
return realm, nil
}
func NewRealm(realm models.Realm, user models.Account) (models.Realm, error) {
realm.Members = []models.RealmMember{
{AccountID: user.ID, PowerLevel: 100},
}
err := database.C.Save(&realm).Error
return realm, err
}
func ListRealmMember(realmId uint) ([]models.RealmMember, error) {
var members []models.RealmMember
if err := database.C.
Where(&models.RealmMember{RealmID: realmId}).
Preload("Account").
Find(&members).Error; err != nil {
return members, err
}
return members, nil
}
func GetRealmMember(userId uint, realmId uint) (models.RealmMember, error) {
var member models.RealmMember
if err := database.C.Where(&models.RealmMember{
AccountID: userId,
RealmID: realmId,
}).Find(&member).Error; err != nil {
return member, err
}
return member, nil
}
func AddRealmMember(user models.Account, affected models.Account, target models.Realm) error {
if !target.IsPublic && !target.IsCommunity {
if member, err := GetRealmMember(user.ID, target.ID); err != nil {
return fmt.Errorf("only realm member can add people: %v", err)
} else if member.PowerLevel < 50 {
return fmt.Errorf("only realm moderator can add people")
}
friendship, err := GetFriendWithTwoSides(affected.ID, user.ID)
if err != nil || friendship.Status != models.FriendshipActive {
return fmt.Errorf("you only can add your friends to your realm")
}
}
member := models.RealmMember{
RealmID: target.ID,
AccountID: affected.ID,
}
err := database.C.Save(&member).Error
return err
}
func RemoveRealmMember(user models.Account, affected models.Account, target models.Realm) error {
if user.ID != affected.ID {
if member, err := GetRealmMember(user.ID, target.ID); err != nil {
return fmt.Errorf("only realm member can remove other member: %v", err)
} else if member.PowerLevel < 50 {
return fmt.Errorf("only realm moderator can invite people")
}
}
var member models.RealmMember
if err := database.C.Where(&models.RealmMember{
RealmID: target.ID,
AccountID: affected.ID,
}).First(&member).Error; err != nil {
return err
}
return database.C.Delete(&member).Error
}
func EditRealm(realm models.Realm) (models.Realm, error) {
err := database.C.Save(&realm).Error
return realm, err
}
func DeleteRealm(realm models.Realm) error {
return database.C.Delete(&realm).Error
}

View File

@ -0,0 +1,24 @@
package services
import (
"git.solsynth.dev/hydrogen/passport/pkg/database"
"git.solsynth.dev/hydrogen/passport/pkg/models"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"time"
)
func DoAutoSignoff() {
duration := time.Duration(viper.GetInt64("security.auto_signoff_duration")) * time.Second
divider := time.Now().Add(-duration)
log.Debug().Time("before", divider).Msg("Now signing off tickets...")
if tx := database.C.
Where("last_grant_at < ?", divider).
Delete(&models.AuthTicket{}); tx.Error != nil {
log.Error().Err(tx.Error).Msg("An error occurred when running auto sign off...")
} else {
log.Debug().Int64("affected", tx.RowsAffected).Msg("Auto sign off accomplished.")
}
}

View File

@ -0,0 +1,158 @@
package services
import (
"fmt"
"time"
"github.com/google/uuid"
"git.solsynth.dev/hydrogen/passport/pkg/database"
"git.solsynth.dev/hydrogen/passport/pkg/models"
"github.com/samber/lo"
)
func DetectRisk(user models.Account, ip, ua string) bool {
var availableFactor int64
if err := database.C.
Where(models.AuthFactor{AccountID: user.ID}).
Where("type != ?", models.PasswordAuthFactor).
Model(models.AuthFactor{}).
Where(&availableFactor); err != nil || availableFactor <= 0 {
return false
}
var secureFactor int64
if err := database.C.
Where(models.AuthTicket{AccountID: user.ID, IpAddress: ip}).
Where("available_at IS NOT NULL").
Model(models.AuthTicket{}).
Count(&secureFactor).Error; err == nil {
if secureFactor >= 1 {
return false
}
}
return true
}
func NewTicket(user models.Account, ip, ua string) (models.AuthTicket, error) {
var ticket models.AuthTicket
if err := database.C.
Where("account_id = ? AND expired_at < ? AND available_at IS NULL", time.Now(), user.ID).
First(&ticket).Error; err == nil {
return ticket, nil
}
requireMFA := DetectRisk(user, ip, ua)
if count := CountUserFactor(user.ID); count <= 1 {
requireMFA = false
}
ticket = models.AuthTicket{
Claims: []string{"*"},
Audiences: []string{"passport"},
IpAddress: ip,
UserAgent: ua,
RequireMFA: requireMFA,
RequireAuthenticate: true,
ExpiredAt: nil,
AvailableAt: nil,
AccountID: user.ID,
}
err := database.C.Save(&ticket).Error
return ticket, err
}
func NewOauthTicket(
user models.Account,
client models.ThirdClient,
claims, audiences []string,
ip, ua string,
) (models.AuthTicket, error) {
ticket := models.AuthTicket{
Claims: claims,
Audiences: audiences,
IpAddress: ip,
UserAgent: ua,
RequireMFA: DetectRisk(user, ip, ua),
GrantToken: lo.ToPtr(uuid.NewString()),
AccessToken: lo.ToPtr(uuid.NewString()),
RefreshToken: lo.ToPtr(uuid.NewString()),
AvailableAt: lo.ToPtr(time.Now()),
ExpiredAt: lo.ToPtr(time.Now()),
ClientID: &client.ID,
AccountID: user.ID,
}
if err := database.C.Save(&ticket).Error; err != nil {
return ticket, err
}
return ticket, nil
}
func ActiveTicketWithPassword(ticket models.AuthTicket, password string) (models.AuthTicket, error) {
if ticket.AvailableAt != nil {
return ticket, nil
} else if !ticket.RequireAuthenticate {
return ticket, nil
}
if factor, err := GetPasswordTypeFactor(ticket.AccountID); err != nil {
return ticket, fmt.Errorf("unable to active ticket: %v", err)
} else if err = CheckFactor(factor, password); err != nil {
return ticket, err
}
ticket.RequireAuthenticate = false
if !ticket.RequireAuthenticate && !ticket.RequireMFA {
ticket.AvailableAt = lo.ToPtr(time.Now())
ticket.GrantToken = lo.ToPtr(uuid.NewString())
ticket.AccessToken = lo.ToPtr(uuid.NewString())
ticket.RefreshToken = lo.ToPtr(uuid.NewString())
}
if err := database.C.Save(&ticket).Error; err != nil {
return ticket, err
}
return ticket, nil
}
func ActiveTicketWithMFA(ticket models.AuthTicket, factor models.AuthFactor, code string) (models.AuthTicket, error) {
if ticket.AvailableAt != nil {
return ticket, nil
} else if !ticket.RequireMFA {
return ticket, nil
}
if err := CheckFactor(factor, code); err != nil {
return ticket, fmt.Errorf("invalid code: %v", err)
}
ticket.RequireMFA = false
if !ticket.RequireAuthenticate && !ticket.RequireMFA {
ticket.AvailableAt = lo.ToPtr(time.Now())
ticket.GrantToken = lo.ToPtr(uuid.NewString())
ticket.AccessToken = lo.ToPtr(uuid.NewString())
ticket.RefreshToken = lo.ToPtr(uuid.NewString())
}
if err := database.C.Save(&ticket).Error; err != nil {
return ticket, err
}
return ticket, nil
}
func RegenSession(ticket models.AuthTicket) (models.AuthTicket, error) {
ticket.GrantToken = lo.ToPtr(uuid.NewString())
ticket.AccessToken = lo.ToPtr(uuid.NewString())
ticket.RefreshToken = lo.ToPtr(uuid.NewString())
err := database.C.Save(&ticket).Error
return ticket, err
}

View File

@ -0,0 +1,29 @@
package services
import (
"git.solsynth.dev/hydrogen/passport/pkg/database"
"git.solsynth.dev/hydrogen/passport/pkg/models"
)
func GetTicket(id uint) (models.AuthTicket, error) {
var ticket models.AuthTicket
if err := database.C.
Where(&models.AuthTicket{BaseModel: models.BaseModel{ID: id}}).
First(&ticket).Error; err != nil {
return ticket, err
}
return ticket, nil
}
func GetTicketWithToken(tokenId string) (models.AuthTicket, error) {
var ticket models.AuthTicket
if err := database.C.
Where(models.AuthTicket{AccessToken: &tokenId}).
Or(models.AuthTicket{RefreshToken: &tokenId}).
First(&ticket).Error; err != nil {
return ticket, err
}
return ticket, nil
}

View File

@ -0,0 +1,98 @@
package services
import (
"fmt"
"git.solsynth.dev/hydrogen/passport/pkg/database"
"git.solsynth.dev/hydrogen/passport/pkg/models"
"github.com/samber/lo"
"github.com/spf13/viper"
"strconv"
"time"
)
func GetToken(ticket models.AuthTicket) (string, string, error) {
var refresh, access string
if err := ticket.IsAvailable(); err != nil {
return refresh, access, err
}
if ticket.AccessToken == nil || ticket.RefreshToken == nil {
return refresh, access, fmt.Errorf("unable to encode token, access or refresh token id missing")
}
accessDuration := time.Duration(viper.GetInt64("security.access_token_duration")) * time.Second
refreshDuration := time.Duration(viper.GetInt64("security.refresh_token_duration")) * time.Second
var err error
sub := strconv.Itoa(int(ticket.AccountID))
sed := strconv.Itoa(int(ticket.ID))
access, err = EncodeJwt(*ticket.AccessToken, JwtAccessType, sub, sed, ticket.Audiences, time.Now().Add(accessDuration))
if err != nil {
return refresh, access, err
}
refresh, err = EncodeJwt(*ticket.RefreshToken, JwtRefreshType, sub, sed, ticket.Audiences, time.Now().Add(refreshDuration))
if err != nil {
return refresh, access, err
}
ticket.LastGrantAt = lo.ToPtr(time.Now())
database.C.Save(&ticket)
return access, refresh, nil
}
func ExchangeToken(token string) (string, string, error) {
var ticket models.AuthTicket
if err := database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil {
return "", "", err
} else if ticket.LastGrantAt != nil {
return "", "", fmt.Errorf("ticket was granted the first token, use refresh token instead")
} else if len(ticket.Audiences) > 1 {
return "", "", fmt.Errorf("should use authorization code grant type")
}
return GetToken(ticket)
}
func ExchangeOauthToken(clientId, clientSecret, redirectUri, token string) (string, string, error) {
var client models.ThirdClient
if err := database.C.Where(models.ThirdClient{Alias: clientId}).First(&client).Error; err != nil {
return "", "", err
} else if client.Secret != clientSecret {
return "", "", fmt.Errorf("invalid client secret")
} else if !client.IsDraft && !lo.Contains(client.Callbacks, redirectUri) {
return "", "", fmt.Errorf("invalid redirect uri")
}
var ticket models.AuthTicket
if err := database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil {
return "", "", err
} else if ticket.LastGrantAt != nil {
return "", "", fmt.Errorf("ticket was granted the first token, use refresh token instead")
}
return GetToken(ticket)
}
func RefreshToken(token string) (string, string, error) {
parseInt := func(str string) int {
val, _ := strconv.Atoi(str)
return val
}
var ticket models.AuthTicket
if claims, err := DecodeJwt(token); err != nil {
return "404", "403", err
} else if claims.Type != JwtRefreshType {
return "404", "403", fmt.Errorf("invalid token type, expected refresh token")
} else if err := database.C.Where(models.AuthTicket{
BaseModel: models.BaseModel{ID: uint(parseInt(claims.SessionID))},
}).First(&ticket).Error; err != nil {
return "404", "403", err
}
if ticket, err := RegenSession(ticket); err != nil {
return "404", "403", err
} else {
return GetToken(ticket)
}
}

View File

@ -0,0 +1,92 @@
package services
import (
"fmt"
"strings"
"time"
"git.solsynth.dev/hydrogen/passport/pkg/database"
"git.solsynth.dev/hydrogen/passport/pkg/models"
"github.com/google/uuid"
"github.com/spf13/viper"
)
const ConfirmRegistrationTemplate = `Dear %s,
Thank you for choosing to register with %s. We are excited to welcome you to our community and appreciate your trust in us.
Your registration details have been successfully received, and you are now a valued member of %s. Here are the confirm link of your registration:
%s
As a confirmed registered member, you will have access to all our services.
We encourage you to explore our services and take full advantage of the resources available to you.
Once again, thank you for choosing us. We look forward to serving you and hope you have a positive experience with us.
Best regards,
%s`
func ValidateMagicToken(code string, mode models.MagicTokenType) (models.MagicToken, error) {
var tk models.MagicToken
if err := database.C.Where(models.MagicToken{Code: code, Type: mode}).First(&tk).Error; err != nil {
return tk, err
} else if tk.ExpiredAt != nil && time.Now().Unix() >= tk.ExpiredAt.Unix() {
return tk, fmt.Errorf("token has been expired")
}
return tk, nil
}
func NewMagicToken(mode models.MagicTokenType, assignTo *models.Account, expiredAt *time.Time) (models.MagicToken, error) {
var uid uint
if assignTo != nil {
uid = assignTo.ID
}
token := models.MagicToken{
Code: strings.Replace(uuid.NewString(), "-", "", -1),
Type: mode,
AssignTo: &uid,
ExpiredAt: expiredAt,
}
if err := database.C.Save(&token).Error; err != nil {
return token, err
} else {
return token, nil
}
}
func NotifyMagicToken(token models.MagicToken) error {
if token.AssignTo == nil {
return fmt.Errorf("could notify a non-assign magic token")
}
var user models.Account
if err := database.C.Where(&models.Account{
BaseModel: models.BaseModel{ID: *token.AssignTo},
}).Preload("Contacts").First(&user).Error; err != nil {
return err
}
var subject string
var content string
switch token.Type {
case models.ConfirmMagicToken:
link := fmt.Sprintf("https://%s/me/confirm?tk=%s", viper.GetString("domain"), token.Code)
subject = fmt.Sprintf("[%s] Confirm your registration", viper.GetString("name"))
content = fmt.Sprintf(
ConfirmRegistrationTemplate,
user.Name,
viper.GetString("name"),
viper.GetString("maintainer"),
link,
viper.GetString("maintainer"),
)
default:
return fmt.Errorf("unsupported magic token type to notify")
}
return SendMail(user.GetPrimaryEmail().Content, subject, content)
}