⚡ Add a cache layer in auth to speed up auth
This commit is contained in:
parent
211959167a
commit
b69ac44885
1
go.mod
1
go.mod
@ -70,6 +70,7 @@ require (
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasthttp v1.51.0 // indirect
|
||||
github.com/valyala/tcplisten v1.0.0 // indirect
|
||||
go.etcd.io/bbolt v1.3.9 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 // indirect
|
||||
|
2
go.sum
2
go.sum
@ -198,6 +198,8 @@ github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdI
|
||||
github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8=
|
||||
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI=
|
||||
go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE=
|
||||
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
|
||||
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
|
@ -36,11 +36,14 @@ func main() {
|
||||
}
|
||||
|
||||
// Connect to database
|
||||
if err := database.NewSource(); err != nil {
|
||||
if err := database.NewGorm(); err != nil {
|
||||
log.Fatal().Err(err).Msg("An error occurred when connect to database.")
|
||||
} else if err := database.RunMigration(database.C); err != nil {
|
||||
log.Fatal().Err(err).Msg("An error occurred when running database auto migration.")
|
||||
}
|
||||
if err := database.NewBolt(); err != nil {
|
||||
log.Fatal().Err(err).Msg("An error occurred when init bolt db.")
|
||||
}
|
||||
|
||||
// External
|
||||
// All the things are optional so when error occurred the server won't crash
|
||||
@ -83,4 +86,6 @@ func main() {
|
||||
log.Info().Msgf("Identity v%s is quitting...", identity.AppVersion)
|
||||
|
||||
quartz.Stop()
|
||||
|
||||
database.B.Close()
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
"github.com/spf13/viper"
|
||||
"go.etcd.io/bbolt"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
@ -12,7 +13,7 @@ import (
|
||||
|
||||
var C *gorm.DB
|
||||
|
||||
func NewSource() error {
|
||||
func NewGorm() error {
|
||||
var err error
|
||||
|
||||
dialector := postgres.Open(viper.GetString("database.dsn"))
|
||||
@ -21,8 +22,19 @@ func NewSource() error {
|
||||
}, Logger: logger.New(&log.Logger, logger.Config{
|
||||
Colorful: true,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
LogLevel: lo.Ternary(viper.GetBool("debug"), logger.Info, logger.Silent),
|
||||
LogLevel: lo.Ternary(viper.GetBool("debug.database"), logger.Info, logger.Silent),
|
||||
})})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
var B *bbolt.DB
|
||||
|
||||
func NewBolt() error {
|
||||
var err error
|
||||
|
||||
dsn := viper.GetString("database.bolt")
|
||||
B, err = bbolt.Open(dsn, 0600, nil)
|
||||
|
||||
return err
|
||||
}
|
||||
|
34
pkg/grpc/auth.go
Normal file
34
pkg/grpc/auth.go
Normal file
@ -0,0 +1,34 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"git.solsynth.dev/hydrogen/identity/pkg/grpc/proto"
|
||||
"git.solsynth.dev/hydrogen/identity/pkg/services"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func (v *Server) Authenticate(_ context.Context, in *proto.AuthRequest) (*proto.AuthReply, error) {
|
||||
user, atk, rtk, err := services.Authenticate(in.GetAccessToken(), in.GetRefreshToken(), 0)
|
||||
if err != nil {
|
||||
return &proto.AuthReply{
|
||||
IsValid: false,
|
||||
}, nil
|
||||
} else {
|
||||
return &proto.AuthReply{
|
||||
IsValid: true,
|
||||
AccessToken: &atk,
|
||||
RefreshToken: &rtk,
|
||||
Userinfo: &proto.Userinfo{
|
||||
Id: uint64(user.ID),
|
||||
Name: user.Name,
|
||||
Nick: user.Nick,
|
||||
Email: user.GetPrimaryEmail().Content,
|
||||
Avatar: fmt.Sprintf("https://%s/api/avatar/%s", viper.GetString("domain"), user.Avatar),
|
||||
Banner: fmt.Sprintf("https://%s/api/avatar/%s", viper.GetString("domain"), user.Banner),
|
||||
Description: &user.Description,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
35
pkg/grpc/notify.go
Normal file
35
pkg/grpc/notify.go
Normal file
@ -0,0 +1,35 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.solsynth.dev/hydrogen/identity/pkg/grpc/proto"
|
||||
"git.solsynth.dev/hydrogen/identity/pkg/models"
|
||||
"git.solsynth.dev/hydrogen/identity/pkg/services"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
func (v *Server) NotifyUser(_ context.Context, in *proto.NotifyRequest) (*proto.NotifyReply, error) {
|
||||
client, err := services.GetThirdClientWithSecret(in.GetClientId(), in.GetClientSecret())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var user models.Account
|
||||
if user, err = services.GetAccount(uint(in.GetRecipientId())); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
links := lo.Map(in.GetLinks(), func(item *proto.NotifyLink, index int) models.NotificationLink {
|
||||
return models.NotificationLink{
|
||||
Label: item.Label,
|
||||
Url: item.Url,
|
||||
}
|
||||
})
|
||||
|
||||
if err := services.NewNotification(client, user, in.Subject, in.Content, links, in.IsImportant); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &proto.NotifyReply{IsSent: true}, nil
|
||||
}
|
@ -1,14 +1,9 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"git.solsynth.dev/hydrogen/identity/pkg/grpc/proto"
|
||||
"git.solsynth.dev/hydrogen/identity/pkg/models"
|
||||
"git.solsynth.dev/hydrogen/identity/pkg/services"
|
||||
"github.com/samber/lo"
|
||||
"github.com/spf13/viper"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/reflection"
|
||||
@ -19,55 +14,6 @@ type Server struct {
|
||||
proto.UnimplementedNotifyServer
|
||||
}
|
||||
|
||||
func (v *Server) Authenticate(_ context.Context, in *proto.AuthRequest) (*proto.AuthReply, error) {
|
||||
user, atk, rtk, err := services.Authenticate(in.GetAccessToken(), in.GetRefreshToken(), 0)
|
||||
if err != nil {
|
||||
return &proto.AuthReply{
|
||||
IsValid: false,
|
||||
}, nil
|
||||
} else {
|
||||
return &proto.AuthReply{
|
||||
IsValid: true,
|
||||
AccessToken: &atk,
|
||||
RefreshToken: &rtk,
|
||||
Userinfo: &proto.Userinfo{
|
||||
Id: uint64(user.ID),
|
||||
Name: user.Name,
|
||||
Nick: user.Nick,
|
||||
Email: user.GetPrimaryEmail().Content,
|
||||
Avatar: fmt.Sprintf("https://%s/api/avatar/%s", viper.GetString("domain"), user.Avatar),
|
||||
Banner: fmt.Sprintf("https://%s/api/avatar/%s", viper.GetString("domain"), user.Banner),
|
||||
Description: &user.Description,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (v *Server) NotifyUser(_ context.Context, in *proto.NotifyRequest) (*proto.NotifyReply, error) {
|
||||
client, err := services.GetThirdClientWithSecret(in.GetClientId(), in.GetClientSecret())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var user models.Account
|
||||
if user, err = services.GetAccount(uint(in.GetRecipientId())); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
links := lo.Map(in.GetLinks(), func(item *proto.NotifyLink, index int) models.NotificationLink {
|
||||
return models.NotificationLink{
|
||||
Label: item.Label,
|
||||
Url: item.Url,
|
||||
}
|
||||
})
|
||||
|
||||
if err := services.NewNotification(client, user, in.Subject, in.Content, links, in.IsImportant); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &proto.NotifyReply{IsSent: true}, nil
|
||||
}
|
||||
|
||||
func StartGrpc() error {
|
||||
listen, err := net.Listen("tcp", viper.GetString("grpc_bind"))
|
||||
if err != nil {
|
||||
|
@ -81,3 +81,9 @@ func (v AuthChallenge) IsAvailable() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type AuthContext struct {
|
||||
Session AuthSession `json:"session"`
|
||||
Account Account `json:"account"`
|
||||
ExpiredAt time.Time `json:"expired_at"`
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ func NewServer() {
|
||||
ProxyHeader: fiber.HeaderXForwardedFor,
|
||||
JSONEncoder: jsoniter.ConfigCompatibleWithStandardLibrary.Marshal,
|
||||
JSONDecoder: jsoniter.ConfigCompatibleWithStandardLibrary.Unmarshal,
|
||||
EnablePrintRoutes: viper.GetBool("debug"),
|
||||
EnablePrintRoutes: viper.GetBool("debug.print_routes"),
|
||||
})
|
||||
|
||||
A.Use(idempotency.New())
|
||||
|
@ -2,12 +2,19 @@ package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.solsynth.dev/hydrogen/identity/pkg/database"
|
||||
"git.solsynth.dev/hydrogen/identity/pkg/models"
|
||||
"git.solsynth.dev/hydrogen/identity/pkg/security"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
jsoniter "github.com/json-iterator/go"
|
||||
"github.com/rs/zerolog/log"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
const authContextBucket = "AuthContext"
|
||||
|
||||
func Authenticate(access, refresh string, depth int) (models.Account, string, string, error) {
|
||||
var user models.Account
|
||||
claims, err := security.DecodeJwt(access)
|
||||
@ -22,17 +29,99 @@ func Authenticate(access, refresh string, depth int) (models.Account, string, st
|
||||
return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("invalid auth key: %v", err))
|
||||
}
|
||||
|
||||
session, err := LookupSessionWithToken(claims.ID)
|
||||
if err != nil {
|
||||
return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("invalid auth session: %v", err))
|
||||
} else if err := session.IsAvailable(); err != nil {
|
||||
return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("unavailable auth session: %v", err))
|
||||
var ctx models.AuthContext
|
||||
|
||||
ctx, lookupErr := GetAuthContext(claims.ID)
|
||||
if lookupErr == nil {
|
||||
log.Debug().Str("jti", claims.ID).Msg("Hit auth context cache once!")
|
||||
return ctx.Account, access, refresh, nil
|
||||
}
|
||||
|
||||
user, err = GetAccount(session.AccountID)
|
||||
if err != nil {
|
||||
return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("invalid account: %v", err))
|
||||
ctx, err = GrantAuthContext(claims.ID)
|
||||
if err == nil {
|
||||
log.Debug().Str("jti", claims.ID).Err(lookupErr).Msg("Missed auth context cache once!")
|
||||
return user, access, refresh, nil
|
||||
}
|
||||
|
||||
return user, access, refresh, nil
|
||||
return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, err.Error())
|
||||
}
|
||||
|
||||
func GetAuthContext(jti string) (models.AuthContext, error) {
|
||||
var err error
|
||||
var ctx models.AuthContext
|
||||
|
||||
err = database.B.View(func(tx *bbolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte(authContextBucket))
|
||||
if bucket == nil {
|
||||
return fmt.Errorf("unable to find auth context bucket")
|
||||
}
|
||||
|
||||
raw := bucket.Get([]byte(jti))
|
||||
if raw == nil {
|
||||
return fmt.Errorf("unable to find auth context")
|
||||
} else if err := jsoniter.Unmarshal(raw, &ctx); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal auth context: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err == nil && time.Now().Unix() >= ctx.ExpiredAt.Unix() {
|
||||
RevokeAuthContext(jti)
|
||||
|
||||
return ctx, fmt.Errorf("auth context has been expired")
|
||||
}
|
||||
|
||||
return ctx, err
|
||||
}
|
||||
|
||||
func GrantAuthContext(jti string) (models.AuthContext, error) {
|
||||
var ctx models.AuthContext
|
||||
|
||||
// Query data from primary database
|
||||
session, err := LookupSessionWithToken(jti)
|
||||
if err != nil {
|
||||
return ctx, fmt.Errorf("invalid auth session: %v", err)
|
||||
} else if err := session.IsAvailable(); err != nil {
|
||||
return ctx, fmt.Errorf("unavailable auth session: %v", err)
|
||||
}
|
||||
|
||||
user, err := GetAccount(session.AccountID)
|
||||
if err != nil {
|
||||
return ctx, fmt.Errorf("invalid account: %v", err)
|
||||
}
|
||||
|
||||
// Every context should expires in some while
|
||||
// Once user update their account info, this will have delay to update
|
||||
ctx = models.AuthContext{
|
||||
Session: session,
|
||||
Account: user,
|
||||
ExpiredAt: time.Now().Add(5 * time.Minute),
|
||||
}
|
||||
|
||||
// Save data into KV cache
|
||||
return ctx, database.B.Update(func(tx *bbolt.Tx) error {
|
||||
bucket, err := tx.CreateBucketIfNotExists([]byte(authContextBucket))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
raw, err := jsoniter.Marshal(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bucket.Put([]byte(jti), raw)
|
||||
})
|
||||
}
|
||||
|
||||
func RevokeAuthContext(jti string) error {
|
||||
return database.B.Update(func(tx *bbolt.Tx) error {
|
||||
bucket, err := tx.CreateBucketIfNotExists([]byte(authContextBucket))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bucket.Delete([]byte(jti))
|
||||
})
|
||||
}
|
||||
|
@ -1,5 +1,3 @@
|
||||
debug = true
|
||||
|
||||
name = "Goatpass"
|
||||
maintainer = "SmartSheep Studio"
|
||||
|
||||
@ -12,6 +10,10 @@ content = "uploads"
|
||||
|
||||
use_registration_magic_token = false
|
||||
|
||||
[debug]
|
||||
database = false
|
||||
print_routes = false
|
||||
|
||||
[external.firebase]
|
||||
credentials = "dist/firebase-certs.json"
|
||||
|
||||
@ -32,3 +34,4 @@ refresh_token_duration = 2592000
|
||||
[database]
|
||||
dsn = "host=localhost dbname=hy_identity port=5432 sslmode=disable"
|
||||
prefix = "identity_"
|
||||
bolt = "uploads/bolt.db"
|
||||
|
Loading…
Reference in New Issue
Block a user