Add a cache layer in auth to speed up auth

This commit is contained in:
LittleSheep 2024-03-23 00:28:27 +08:00
parent 211959167a
commit b69ac44885
11 changed files with 202 additions and 69 deletions

1
go.mod
View File

@ -70,6 +70,7 @@ require (
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.51.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect
github.com/valyala/tcplisten v1.0.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect
go.etcd.io/bbolt v1.3.9 // indirect
go.opencensus.io v0.24.0 // indirect go.opencensus.io v0.24.0 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 // indirect golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 // indirect

2
go.sum
View File

@ -198,6 +198,8 @@ github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdI
github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8=
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI=
go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE=
go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=

View File

@ -36,11 +36,14 @@ func main() {
} }
// Connect to database // Connect to database
if err := database.NewSource(); err != nil { if err := database.NewGorm(); err != nil {
log.Fatal().Err(err).Msg("An error occurred when connect to database.") log.Fatal().Err(err).Msg("An error occurred when connect to database.")
} else if err := database.RunMigration(database.C); err != nil { } else if err := database.RunMigration(database.C); err != nil {
log.Fatal().Err(err).Msg("An error occurred when running database auto migration.") log.Fatal().Err(err).Msg("An error occurred when running database auto migration.")
} }
if err := database.NewBolt(); err != nil {
log.Fatal().Err(err).Msg("An error occurred when init bolt db.")
}
// External // External
// All the things are optional so when error occurred the server won't crash // All the things are optional so when error occurred the server won't crash
@ -83,4 +86,6 @@ func main() {
log.Info().Msgf("Identity v%s is quitting...", identity.AppVersion) log.Info().Msgf("Identity v%s is quitting...", identity.AppVersion)
quartz.Stop() quartz.Stop()
database.B.Close()
} }

View File

@ -4,6 +4,7 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/spf13/viper" "github.com/spf13/viper"
"go.etcd.io/bbolt"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
@ -12,7 +13,7 @@ import (
var C *gorm.DB var C *gorm.DB
func NewSource() error { func NewGorm() error {
var err error var err error
dialector := postgres.Open(viper.GetString("database.dsn")) dialector := postgres.Open(viper.GetString("database.dsn"))
@ -21,8 +22,19 @@ func NewSource() error {
}, Logger: logger.New(&log.Logger, logger.Config{ }, Logger: logger.New(&log.Logger, logger.Config{
Colorful: true, Colorful: true,
IgnoreRecordNotFoundError: true, IgnoreRecordNotFoundError: true,
LogLevel: lo.Ternary(viper.GetBool("debug"), logger.Info, logger.Silent), LogLevel: lo.Ternary(viper.GetBool("debug.database"), logger.Info, logger.Silent),
})}) })})
return err return err
} }
var B *bbolt.DB
func NewBolt() error {
var err error
dsn := viper.GetString("database.bolt")
B, err = bbolt.Open(dsn, 0600, nil)
return err
}

34
pkg/grpc/auth.go Normal file
View File

@ -0,0 +1,34 @@
package grpc
import (
"context"
"fmt"
"git.solsynth.dev/hydrogen/identity/pkg/grpc/proto"
"git.solsynth.dev/hydrogen/identity/pkg/services"
"github.com/spf13/viper"
)
func (v *Server) Authenticate(_ context.Context, in *proto.AuthRequest) (*proto.AuthReply, error) {
user, atk, rtk, err := services.Authenticate(in.GetAccessToken(), in.GetRefreshToken(), 0)
if err != nil {
return &proto.AuthReply{
IsValid: false,
}, nil
} else {
return &proto.AuthReply{
IsValid: true,
AccessToken: &atk,
RefreshToken: &rtk,
Userinfo: &proto.Userinfo{
Id: uint64(user.ID),
Name: user.Name,
Nick: user.Nick,
Email: user.GetPrimaryEmail().Content,
Avatar: fmt.Sprintf("https://%s/api/avatar/%s", viper.GetString("domain"), user.Avatar),
Banner: fmt.Sprintf("https://%s/api/avatar/%s", viper.GetString("domain"), user.Banner),
Description: &user.Description,
},
}, nil
}
}

35
pkg/grpc/notify.go Normal file
View File

@ -0,0 +1,35 @@
package grpc
import (
"context"
"git.solsynth.dev/hydrogen/identity/pkg/grpc/proto"
"git.solsynth.dev/hydrogen/identity/pkg/models"
"git.solsynth.dev/hydrogen/identity/pkg/services"
"github.com/samber/lo"
)
func (v *Server) NotifyUser(_ context.Context, in *proto.NotifyRequest) (*proto.NotifyReply, error) {
client, err := services.GetThirdClientWithSecret(in.GetClientId(), in.GetClientSecret())
if err != nil {
return nil, err
}
var user models.Account
if user, err = services.GetAccount(uint(in.GetRecipientId())); err != nil {
return nil, err
}
links := lo.Map(in.GetLinks(), func(item *proto.NotifyLink, index int) models.NotificationLink {
return models.NotificationLink{
Label: item.Label,
Url: item.Url,
}
})
if err := services.NewNotification(client, user, in.Subject, in.Content, links, in.IsImportant); err != nil {
return nil, err
}
return &proto.NotifyReply{IsSent: true}, nil
}

View File

@ -1,14 +1,9 @@
package grpc package grpc
import ( import (
"context"
"fmt"
"net" "net"
"git.solsynth.dev/hydrogen/identity/pkg/grpc/proto" "git.solsynth.dev/hydrogen/identity/pkg/grpc/proto"
"git.solsynth.dev/hydrogen/identity/pkg/models"
"git.solsynth.dev/hydrogen/identity/pkg/services"
"github.com/samber/lo"
"github.com/spf13/viper" "github.com/spf13/viper"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/reflection" "google.golang.org/grpc/reflection"
@ -19,55 +14,6 @@ type Server struct {
proto.UnimplementedNotifyServer proto.UnimplementedNotifyServer
} }
func (v *Server) Authenticate(_ context.Context, in *proto.AuthRequest) (*proto.AuthReply, error) {
user, atk, rtk, err := services.Authenticate(in.GetAccessToken(), in.GetRefreshToken(), 0)
if err != nil {
return &proto.AuthReply{
IsValid: false,
}, nil
} else {
return &proto.AuthReply{
IsValid: true,
AccessToken: &atk,
RefreshToken: &rtk,
Userinfo: &proto.Userinfo{
Id: uint64(user.ID),
Name: user.Name,
Nick: user.Nick,
Email: user.GetPrimaryEmail().Content,
Avatar: fmt.Sprintf("https://%s/api/avatar/%s", viper.GetString("domain"), user.Avatar),
Banner: fmt.Sprintf("https://%s/api/avatar/%s", viper.GetString("domain"), user.Banner),
Description: &user.Description,
},
}, nil
}
}
func (v *Server) NotifyUser(_ context.Context, in *proto.NotifyRequest) (*proto.NotifyReply, error) {
client, err := services.GetThirdClientWithSecret(in.GetClientId(), in.GetClientSecret())
if err != nil {
return nil, err
}
var user models.Account
if user, err = services.GetAccount(uint(in.GetRecipientId())); err != nil {
return nil, err
}
links := lo.Map(in.GetLinks(), func(item *proto.NotifyLink, index int) models.NotificationLink {
return models.NotificationLink{
Label: item.Label,
Url: item.Url,
}
})
if err := services.NewNotification(client, user, in.Subject, in.Content, links, in.IsImportant); err != nil {
return nil, err
}
return &proto.NotifyReply{IsSent: true}, nil
}
func StartGrpc() error { func StartGrpc() error {
listen, err := net.Listen("tcp", viper.GetString("grpc_bind")) listen, err := net.Listen("tcp", viper.GetString("grpc_bind"))
if err != nil { if err != nil {

View File

@ -81,3 +81,9 @@ func (v AuthChallenge) IsAvailable() error {
return nil return nil
} }
type AuthContext struct {
Session AuthSession `json:"session"`
Account Account `json:"account"`
ExpiredAt time.Time `json:"expired_at"`
}

View File

@ -28,7 +28,7 @@ func NewServer() {
ProxyHeader: fiber.HeaderXForwardedFor, ProxyHeader: fiber.HeaderXForwardedFor,
JSONEncoder: jsoniter.ConfigCompatibleWithStandardLibrary.Marshal, JSONEncoder: jsoniter.ConfigCompatibleWithStandardLibrary.Marshal,
JSONDecoder: jsoniter.ConfigCompatibleWithStandardLibrary.Unmarshal, JSONDecoder: jsoniter.ConfigCompatibleWithStandardLibrary.Unmarshal,
EnablePrintRoutes: viper.GetBool("debug"), EnablePrintRoutes: viper.GetBool("debug.print_routes"),
}) })
A.Use(idempotency.New()) A.Use(idempotency.New())

View File

@ -2,12 +2,19 @@ package services
import ( import (
"fmt" "fmt"
"time"
"git.solsynth.dev/hydrogen/identity/pkg/database"
"git.solsynth.dev/hydrogen/identity/pkg/models" "git.solsynth.dev/hydrogen/identity/pkg/models"
"git.solsynth.dev/hydrogen/identity/pkg/security" "git.solsynth.dev/hydrogen/identity/pkg/security"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
jsoniter "github.com/json-iterator/go"
"github.com/rs/zerolog/log"
"go.etcd.io/bbolt"
) )
const authContextBucket = "AuthContext"
func Authenticate(access, refresh string, depth int) (models.Account, string, string, error) { func Authenticate(access, refresh string, depth int) (models.Account, string, string, error) {
var user models.Account var user models.Account
claims, err := security.DecodeJwt(access) claims, err := security.DecodeJwt(access)
@ -22,17 +29,99 @@ func Authenticate(access, refresh string, depth int) (models.Account, string, st
return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("invalid auth key: %v", err)) return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("invalid auth key: %v", err))
} }
session, err := LookupSessionWithToken(claims.ID) var ctx models.AuthContext
if err != nil {
return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("invalid auth session: %v", err)) ctx, lookupErr := GetAuthContext(claims.ID)
} else if err := session.IsAvailable(); err != nil { if lookupErr == nil {
return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("unavailable auth session: %v", err)) log.Debug().Str("jti", claims.ID).Msg("Hit auth context cache once!")
} return ctx.Account, access, refresh, nil
user, err = GetAccount(session.AccountID)
if err != nil {
return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("invalid account: %v", err))
} }
ctx, err = GrantAuthContext(claims.ID)
if err == nil {
log.Debug().Str("jti", claims.ID).Err(lookupErr).Msg("Missed auth context cache once!")
return user, access, refresh, nil return user, access, refresh, nil
}
return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, err.Error())
}
func GetAuthContext(jti string) (models.AuthContext, error) {
var err error
var ctx models.AuthContext
err = database.B.View(func(tx *bbolt.Tx) error {
bucket := tx.Bucket([]byte(authContextBucket))
if bucket == nil {
return fmt.Errorf("unable to find auth context bucket")
}
raw := bucket.Get([]byte(jti))
if raw == nil {
return fmt.Errorf("unable to find auth context")
} else if err := jsoniter.Unmarshal(raw, &ctx); err != nil {
return fmt.Errorf("unable to unmarshal auth context: %v", err)
}
return nil
})
if err == nil && time.Now().Unix() >= ctx.ExpiredAt.Unix() {
RevokeAuthContext(jti)
return ctx, fmt.Errorf("auth context has been expired")
}
return ctx, err
}
func GrantAuthContext(jti string) (models.AuthContext, error) {
var ctx models.AuthContext
// Query data from primary database
session, err := LookupSessionWithToken(jti)
if err != nil {
return ctx, fmt.Errorf("invalid auth session: %v", err)
} else if err := session.IsAvailable(); err != nil {
return ctx, fmt.Errorf("unavailable auth session: %v", err)
}
user, err := GetAccount(session.AccountID)
if err != nil {
return ctx, fmt.Errorf("invalid account: %v", err)
}
// Every context should expires in some while
// Once user update their account info, this will have delay to update
ctx = models.AuthContext{
Session: session,
Account: user,
ExpiredAt: time.Now().Add(5 * time.Minute),
}
// Save data into KV cache
return ctx, database.B.Update(func(tx *bbolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists([]byte(authContextBucket))
if err != nil {
return err
}
raw, err := jsoniter.Marshal(ctx)
if err != nil {
return err
}
return bucket.Put([]byte(jti), raw)
})
}
func RevokeAuthContext(jti string) error {
return database.B.Update(func(tx *bbolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists([]byte(authContextBucket))
if err != nil {
return err
}
return bucket.Delete([]byte(jti))
})
} }

View File

@ -1,5 +1,3 @@
debug = true
name = "Goatpass" name = "Goatpass"
maintainer = "SmartSheep Studio" maintainer = "SmartSheep Studio"
@ -12,6 +10,10 @@ content = "uploads"
use_registration_magic_token = false use_registration_magic_token = false
[debug]
database = false
print_routes = false
[external.firebase] [external.firebase]
credentials = "dist/firebase-certs.json" credentials = "dist/firebase-certs.json"
@ -32,3 +34,4 @@ refresh_token_duration = 2592000
[database] [database]
dsn = "host=localhost dbname=hy_identity port=5432 sslmode=disable" dsn = "host=localhost dbname=hy_identity port=5432 sslmode=disable"
prefix = "identity_" prefix = "identity_"
bolt = "uploads/bolt.db"