diff --git a/go.mod b/go.mod index ebe2894..0f5108e 100644 --- a/go.mod +++ b/go.mod @@ -70,6 +70,7 @@ require ( github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect + go.etcd.io/bbolt v1.3.9 // indirect go.opencensus.io v0.24.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 // indirect diff --git a/go.sum b/go.sum index ff7dace..12d14c9 100644 --- a/go.sum +++ b/go.sum @@ -198,6 +198,8 @@ github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdI github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.etcd.io/bbolt v1.3.9 h1:8x7aARPEXiXbHmtUwAIv7eV2fQFHrLLavdiJ3uzJXoI= +go.etcd.io/bbolt v1.3.9/go.mod h1:zaO32+Ti0PK1ivdPtgMESzuzL2VPoIG1PCQNvOdo/dE= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= diff --git a/pkg/cmd/main.go b/pkg/cmd/main.go index 9135e27..d317fa7 100644 --- a/pkg/cmd/main.go +++ b/pkg/cmd/main.go @@ -36,11 +36,14 @@ func main() { } // Connect to database - if err := database.NewSource(); err != nil { + if err := database.NewGorm(); err != nil { log.Fatal().Err(err).Msg("An error occurred when connect to database.") } else if err := database.RunMigration(database.C); err != nil { log.Fatal().Err(err).Msg("An error occurred when running database auto migration.") } + if err := database.NewBolt(); err != nil { + log.Fatal().Err(err).Msg("An error occurred when init bolt db.") + } // External // All the things are optional so when error occurred the server won't crash @@ -83,4 +86,6 @@ func main() { log.Info().Msgf("Identity v%s is quitting...", identity.AppVersion) quartz.Stop() + + database.B.Close() } diff --git a/pkg/database/source.go b/pkg/database/source.go index 5fdb79d..db76666 100644 --- a/pkg/database/source.go +++ b/pkg/database/source.go @@ -4,6 +4,7 @@ import ( "github.com/rs/zerolog/log" "github.com/samber/lo" "github.com/spf13/viper" + "go.etcd.io/bbolt" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" @@ -12,7 +13,7 @@ import ( var C *gorm.DB -func NewSource() error { +func NewGorm() error { var err error dialector := postgres.Open(viper.GetString("database.dsn")) @@ -21,8 +22,19 @@ func NewSource() error { }, Logger: logger.New(&log.Logger, logger.Config{ Colorful: true, IgnoreRecordNotFoundError: true, - LogLevel: lo.Ternary(viper.GetBool("debug"), logger.Info, logger.Silent), + LogLevel: lo.Ternary(viper.GetBool("debug.database"), logger.Info, logger.Silent), })}) return err } + +var B *bbolt.DB + +func NewBolt() error { + var err error + + dsn := viper.GetString("database.bolt") + B, err = bbolt.Open(dsn, 0600, nil) + + return err +} diff --git a/pkg/grpc/auth.go b/pkg/grpc/auth.go new file mode 100644 index 0000000..a7f3a2c --- /dev/null +++ b/pkg/grpc/auth.go @@ -0,0 +1,34 @@ +package grpc + +import ( + "context" + "fmt" + + "git.solsynth.dev/hydrogen/identity/pkg/grpc/proto" + "git.solsynth.dev/hydrogen/identity/pkg/services" + "github.com/spf13/viper" +) + +func (v *Server) Authenticate(_ context.Context, in *proto.AuthRequest) (*proto.AuthReply, error) { + user, atk, rtk, err := services.Authenticate(in.GetAccessToken(), in.GetRefreshToken(), 0) + if err != nil { + return &proto.AuthReply{ + IsValid: false, + }, nil + } else { + return &proto.AuthReply{ + IsValid: true, + AccessToken: &atk, + RefreshToken: &rtk, + Userinfo: &proto.Userinfo{ + Id: uint64(user.ID), + Name: user.Name, + Nick: user.Nick, + Email: user.GetPrimaryEmail().Content, + Avatar: fmt.Sprintf("https://%s/api/avatar/%s", viper.GetString("domain"), user.Avatar), + Banner: fmt.Sprintf("https://%s/api/avatar/%s", viper.GetString("domain"), user.Banner), + Description: &user.Description, + }, + }, nil + } +} diff --git a/pkg/grpc/notify.go b/pkg/grpc/notify.go new file mode 100644 index 0000000..3e449bb --- /dev/null +++ b/pkg/grpc/notify.go @@ -0,0 +1,35 @@ +package grpc + +import ( + "context" + + "git.solsynth.dev/hydrogen/identity/pkg/grpc/proto" + "git.solsynth.dev/hydrogen/identity/pkg/models" + "git.solsynth.dev/hydrogen/identity/pkg/services" + "github.com/samber/lo" +) + +func (v *Server) NotifyUser(_ context.Context, in *proto.NotifyRequest) (*proto.NotifyReply, error) { + client, err := services.GetThirdClientWithSecret(in.GetClientId(), in.GetClientSecret()) + if err != nil { + return nil, err + } + + var user models.Account + if user, err = services.GetAccount(uint(in.GetRecipientId())); err != nil { + return nil, err + } + + links := lo.Map(in.GetLinks(), func(item *proto.NotifyLink, index int) models.NotificationLink { + return models.NotificationLink{ + Label: item.Label, + Url: item.Url, + } + }) + + if err := services.NewNotification(client, user, in.Subject, in.Content, links, in.IsImportant); err != nil { + return nil, err + } + + return &proto.NotifyReply{IsSent: true}, nil +} diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 8b77cc5..4b244a3 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -1,14 +1,9 @@ package grpc import ( - "context" - "fmt" "net" "git.solsynth.dev/hydrogen/identity/pkg/grpc/proto" - "git.solsynth.dev/hydrogen/identity/pkg/models" - "git.solsynth.dev/hydrogen/identity/pkg/services" - "github.com/samber/lo" "github.com/spf13/viper" "google.golang.org/grpc" "google.golang.org/grpc/reflection" @@ -19,55 +14,6 @@ type Server struct { proto.UnimplementedNotifyServer } -func (v *Server) Authenticate(_ context.Context, in *proto.AuthRequest) (*proto.AuthReply, error) { - user, atk, rtk, err := services.Authenticate(in.GetAccessToken(), in.GetRefreshToken(), 0) - if err != nil { - return &proto.AuthReply{ - IsValid: false, - }, nil - } else { - return &proto.AuthReply{ - IsValid: true, - AccessToken: &atk, - RefreshToken: &rtk, - Userinfo: &proto.Userinfo{ - Id: uint64(user.ID), - Name: user.Name, - Nick: user.Nick, - Email: user.GetPrimaryEmail().Content, - Avatar: fmt.Sprintf("https://%s/api/avatar/%s", viper.GetString("domain"), user.Avatar), - Banner: fmt.Sprintf("https://%s/api/avatar/%s", viper.GetString("domain"), user.Banner), - Description: &user.Description, - }, - }, nil - } -} - -func (v *Server) NotifyUser(_ context.Context, in *proto.NotifyRequest) (*proto.NotifyReply, error) { - client, err := services.GetThirdClientWithSecret(in.GetClientId(), in.GetClientSecret()) - if err != nil { - return nil, err - } - - var user models.Account - if user, err = services.GetAccount(uint(in.GetRecipientId())); err != nil { - return nil, err - } - - links := lo.Map(in.GetLinks(), func(item *proto.NotifyLink, index int) models.NotificationLink { - return models.NotificationLink{ - Label: item.Label, - Url: item.Url, - } - }) - - if err := services.NewNotification(client, user, in.Subject, in.Content, links, in.IsImportant); err != nil { - return nil, err - } - - return &proto.NotifyReply{IsSent: true}, nil -} - func StartGrpc() error { listen, err := net.Listen("tcp", viper.GetString("grpc_bind")) if err != nil { diff --git a/pkg/models/auth.go b/pkg/models/auth.go index 76f51c9..e74d5a3 100644 --- a/pkg/models/auth.go +++ b/pkg/models/auth.go @@ -81,3 +81,9 @@ func (v AuthChallenge) IsAvailable() error { return nil } + +type AuthContext struct { + Session AuthSession `json:"session"` + Account Account `json:"account"` + ExpiredAt time.Time `json:"expired_at"` +} diff --git a/pkg/server/startup.go b/pkg/server/startup.go index 64725eb..e085eaa 100644 --- a/pkg/server/startup.go +++ b/pkg/server/startup.go @@ -28,7 +28,7 @@ func NewServer() { ProxyHeader: fiber.HeaderXForwardedFor, JSONEncoder: jsoniter.ConfigCompatibleWithStandardLibrary.Marshal, JSONDecoder: jsoniter.ConfigCompatibleWithStandardLibrary.Unmarshal, - EnablePrintRoutes: viper.GetBool("debug"), + EnablePrintRoutes: viper.GetBool("debug.print_routes"), }) A.Use(idempotency.New()) diff --git a/pkg/services/auth.go b/pkg/services/auth.go index 57eaec7..bb270fb 100644 --- a/pkg/services/auth.go +++ b/pkg/services/auth.go @@ -2,12 +2,19 @@ package services import ( "fmt" + "time" + "git.solsynth.dev/hydrogen/identity/pkg/database" "git.solsynth.dev/hydrogen/identity/pkg/models" "git.solsynth.dev/hydrogen/identity/pkg/security" "github.com/gofiber/fiber/v2" + jsoniter "github.com/json-iterator/go" + "github.com/rs/zerolog/log" + "go.etcd.io/bbolt" ) +const authContextBucket = "AuthContext" + func Authenticate(access, refresh string, depth int) (models.Account, string, string, error) { var user models.Account claims, err := security.DecodeJwt(access) @@ -22,17 +29,99 @@ func Authenticate(access, refresh string, depth int) (models.Account, string, st return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("invalid auth key: %v", err)) } - session, err := LookupSessionWithToken(claims.ID) - if err != nil { - return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("invalid auth session: %v", err)) - } else if err := session.IsAvailable(); err != nil { - return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("unavailable auth session: %v", err)) + var ctx models.AuthContext + + ctx, lookupErr := GetAuthContext(claims.ID) + if lookupErr == nil { + log.Debug().Str("jti", claims.ID).Msg("Hit auth context cache once!") + return ctx.Account, access, refresh, nil } - user, err = GetAccount(session.AccountID) - if err != nil { - return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("invalid account: %v", err)) + ctx, err = GrantAuthContext(claims.ID) + if err == nil { + log.Debug().Str("jti", claims.ID).Err(lookupErr).Msg("Missed auth context cache once!") + return user, access, refresh, nil } - return user, access, refresh, nil + return user, access, refresh, fiber.NewError(fiber.StatusUnauthorized, err.Error()) +} + +func GetAuthContext(jti string) (models.AuthContext, error) { + var err error + var ctx models.AuthContext + + err = database.B.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket([]byte(authContextBucket)) + if bucket == nil { + return fmt.Errorf("unable to find auth context bucket") + } + + raw := bucket.Get([]byte(jti)) + if raw == nil { + return fmt.Errorf("unable to find auth context") + } else if err := jsoniter.Unmarshal(raw, &ctx); err != nil { + return fmt.Errorf("unable to unmarshal auth context: %v", err) + } + + return nil + }) + + if err == nil && time.Now().Unix() >= ctx.ExpiredAt.Unix() { + RevokeAuthContext(jti) + + return ctx, fmt.Errorf("auth context has been expired") + } + + return ctx, err +} + +func GrantAuthContext(jti string) (models.AuthContext, error) { + var ctx models.AuthContext + + // Query data from primary database + session, err := LookupSessionWithToken(jti) + if err != nil { + return ctx, fmt.Errorf("invalid auth session: %v", err) + } else if err := session.IsAvailable(); err != nil { + return ctx, fmt.Errorf("unavailable auth session: %v", err) + } + + user, err := GetAccount(session.AccountID) + if err != nil { + return ctx, fmt.Errorf("invalid account: %v", err) + } + + // Every context should expires in some while + // Once user update their account info, this will have delay to update + ctx = models.AuthContext{ + Session: session, + Account: user, + ExpiredAt: time.Now().Add(5 * time.Minute), + } + + // Save data into KV cache + return ctx, database.B.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists([]byte(authContextBucket)) + if err != nil { + return err + } + + raw, err := jsoniter.Marshal(ctx) + if err != nil { + return err + } + + return bucket.Put([]byte(jti), raw) + }) +} + +func RevokeAuthContext(jti string) error { + return database.B.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists([]byte(authContextBucket)) + if err != nil { + return err + } + + return bucket.Delete([]byte(jti)) + }) } diff --git a/settings.toml b/settings.toml index 909124d..fa9c9ed 100644 --- a/settings.toml +++ b/settings.toml @@ -1,5 +1,3 @@ -debug = true - name = "Goatpass" maintainer = "SmartSheep Studio" @@ -12,6 +10,10 @@ content = "uploads" use_registration_magic_token = false +[debug] +database = false +print_routes = false + [external.firebase] credentials = "dist/firebase-certs.json" @@ -32,3 +34,4 @@ refresh_token_duration = 2592000 [database] dsn = "host=localhost dbname=hy_identity port=5432 sslmode=disable" prefix = "identity_" +bolt = "uploads/bolt.db"