♻️ Use nats jetstream instead of database to store otp

This commit is contained in:
LittleSheep 2025-01-27 15:43:24 +08:00
parent 2dac1759d9
commit 1f75a9e64b
2 changed files with 74 additions and 14 deletions

View File

@ -1,19 +1,32 @@
package gap package gap
import ( import (
"errors"
"fmt" "fmt"
"strings"
"time"
"git.solsynth.dev/hypernet/nexus/pkg/nex" "git.solsynth.dev/hypernet/nexus/pkg/nex"
"git.solsynth.dev/hypernet/nexus/pkg/nex/rx"
"git.solsynth.dev/hypernet/nexus/pkg/proto" "git.solsynth.dev/hypernet/nexus/pkg/proto"
"git.solsynth.dev/hypernet/pusher/pkg/pushkit/pushcon" "git.solsynth.dev/hypernet/pusher/pkg/pushkit/pushcon"
"github.com/nats-io/nats.go"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/samber/lo" "github.com/samber/lo"
"strings"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
var Nx *nex.Conn var (
var Px *pushcon.Conn Nx *nex.Conn
Px *pushcon.Conn
Rx *rx.MqConn
Jt nats.JetStreamContext
)
const (
FactorOtpPrefix = "passport.otp."
)
func InitializeToNexus() error { func InitializeToNexus() error {
grpcBind := strings.SplitN(viper.GetString("grpc_bind"), ":", 2) grpcBind := strings.SplitN(viper.GetString("grpc_bind"), ":", 2)
@ -46,5 +59,25 @@ func InitializeToNexus() error {
return fmt.Errorf("error during initialize pushcon: %v", err) return fmt.Errorf("error during initialize pushcon: %v", err)
} }
Rx, err = rx.NewMqConn(Nx)
if err != nil {
return fmt.Errorf("error during initialize nexus rx module: %v", err)
}
Jt, err = Rx.Nt.JetStream()
if err != nil {
return fmt.Errorf("error during initialize nats jetstream: %v", err)
}
jetstreamCfg := &nats.StreamConfig{
Name: "Passport OTPs",
Subjects: []string{FactorOtpPrefix + ">"},
Storage: nats.MemoryStorage,
MaxAge: 5 * time.Minute,
}
_, err = Jt.AddStream(jetstreamCfg)
if err != nil && !errors.Is(err, nats.ErrStreamNameAlreadyInUse) {
return fmt.Errorf("error during initialize jetstream stream: %v", err)
}
return err return err
} }

View File

@ -2,15 +2,18 @@ package services
import ( import (
"fmt" "fmt"
"strings"
"time"
"git.solsynth.dev/hypernet/passport/pkg/authkit/models" "git.solsynth.dev/hypernet/passport/pkg/authkit/models"
"git.solsynth.dev/hypernet/passport/pkg/internal/database" "git.solsynth.dev/hypernet/passport/pkg/internal/database"
"git.solsynth.dev/hypernet/passport/pkg/internal/gap" "git.solsynth.dev/hypernet/passport/pkg/internal/gap"
"git.solsynth.dev/hypernet/pusher/pkg/pushkit" "git.solsynth.dev/hypernet/pusher/pkg/pushkit"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/nats-io/nats.go"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/spf13/viper" "github.com/spf13/viper"
"strings"
) )
const EmailPasswordTemplate = `Dear %s, const EmailPasswordTemplate = `Dear %s,
@ -75,15 +78,20 @@ func GetFactorCode(factor models.AuthFactor) (bool, error) {
return true, err return true, err
} }
factor.Secret = uuid.NewString()[:6] secret := uuid.NewString()[:6]
if err := database.C.Save(&factor).Error; err != nil {
return true, err identifier := fmt.Sprintf("%s%d", gap.FactorOtpPrefix, factor.ID)
_, err := gap.Jt.Publish(identifier, []byte(secret))
if err != nil {
return true, fmt.Errorf("error during publish message: %v", err)
} else {
log.Info().Uint("factor", factor.ID).Str("secret", secret).Msg("Published one-time-password to JetStream...")
} }
subject := fmt.Sprintf("[%s] Login verification code", viper.GetString("name")) subject := fmt.Sprintf("[%s] Login verification code", viper.GetString("name"))
content := fmt.Sprintf(EmailPasswordTemplate, user.Name, factor.Secret, viper.GetString("maintainer")) content := fmt.Sprintf(EmailPasswordTemplate, user.Name, secret, viper.GetString("maintainer"))
err := gap.Px.PushEmail(pushkit.EmailDeliverRequest{ err = gap.Px.PushEmail(pushkit.EmailDeliverRequest{
To: user.GetPrimaryEmail().Content, To: user.GetPrimaryEmail().Content,
Email: pushkit.EmailData{ Email: pushkit.EmailData{
Subject: subject, Subject: subject,
@ -110,11 +118,30 @@ func CheckFactor(factor models.AuthFactor, code string) error {
fmt.Errorf("invalid password"), fmt.Errorf("invalid password"),
) )
case models.EmailPasswordFactor: case models.EmailPasswordFactor:
return lo.Ternary( identifier := fmt.Sprintf("%s%d", gap.FactorOtpPrefix, factor.ID)
strings.ToUpper(code) == strings.ToUpper(factor.Secret), sub, err := gap.Jt.PullSubscribe(identifier, "otp_consumer", nats.Durable("otp_consumer"))
nil, if err != nil {
fmt.Errorf("invalid verification code"), log.Error().Err(err).Msg("Error subscribing to subject when validating factor code...")
) return fmt.Errorf("error subscribing to subject: %v", err)
}
msgs, err := sub.Fetch(1, nats.MaxWait(2*time.Second))
if err != nil {
log.Error().Err(err).Msg("Error fetching message when validating factor code...")
return fmt.Errorf("error fetching message: %v", err)
}
if len(msgs) > 0 {
msg := msgs[0]
if !strings.EqualFold(code, string(msg.Data)) {
return fmt.Errorf("invalid verification code")
}
log.Info().Uint("factor", factor.ID).Str("secret", code).Msg("Verified one-time-password...")
msg.Ack()
return nil
}
return fmt.Errorf("one-time-password not found or expired")
} }
return nil return nil