From 6ef46d984d3bd4c3e6c4fb2b5fc526777a5e3716 Mon Sep 17 00:00:00 2001 From: LittleSheep Date: Sun, 28 Jul 2024 20:04:22 +0800 Subject: [PATCH] :sparkles: Better id token in oidc --- pkg/internal/models/accounts.go | 8 +- pkg/internal/server/api/auth_api.go | 22 +++--- pkg/internal/server/api/userinfo_api.go | 2 +- pkg/internal/services/jwt.go | 26 ++++++- pkg/internal/services/ticket_token.go | 97 +++++++++++++++---------- 5 files changed, 97 insertions(+), 58 deletions(-) diff --git a/pkg/internal/models/accounts.go b/pkg/internal/models/accounts.go index be6dc47..749d09a 100644 --- a/pkg/internal/models/accounts.go +++ b/pkg/internal/models/accounts.go @@ -23,18 +23,12 @@ type Account struct { Profile AccountProfile `json:"profile,omitempty"` Contacts []AccountContact `json:"contacts,omitempty"` - Statuses []Status `json:"statuses,omitempty"` Badges []Badge `json:"badges,omitempty"` - Identities []RealmMember `json:"identities,omitempty"` - Tickets []AuthTicket `json:"tickets,omitempty"` Factors []AuthFactor `json:"factors,omitempty"` - Events []ActionEvent `json:"events,omitempty"` - MagicTokens []MagicToken `json:"-"` - - ThirdClients []ThirdClient `json:"clients,omitempty"` + Events []ActionEvent `json:"events,omitempty"` Notifications []Notification `json:"notifications,omitempty"` NotifySubscribers []NotificationSubscriber `json:"notify_subscribers,omitempty"` diff --git a/pkg/internal/server/api/auth_api.go b/pkg/internal/server/api/auth_api.go index f421073..2fb5c95 100644 --- a/pkg/internal/server/api/auth_api.go +++ b/pkg/internal/server/api/auth_api.go @@ -109,17 +109,17 @@ func getToken(c *fiber.Ctx) error { } var err error - var access, refresh string + var idk, atk, rtk string switch data.GrantType { case "refresh_token": // Refresh Token - access, refresh, err = services.RefreshToken(data.RefreshToken) + atk, rtk, err = services.RefreshToken(data.RefreshToken) if err != nil { return fiber.NewError(fiber.StatusBadRequest, err.Error()) } case "authorization_code": // Authorization Code Mode - access, refresh, err = services.ExchangeOauthToken(data.ClientID, data.ClientSecret, data.RedirectUri, data.Code) + idk, atk, rtk, err = services.ExchangeOauthToken(data.ClientID, data.ClientSecret, data.RedirectUri, data.Code) if err != nil { return fiber.NewError(fiber.StatusBadRequest, err.Error()) } @@ -139,13 +139,13 @@ func getToken(c *fiber.Ctx) error { } else if err := ticket.IsAvailable(); err != nil { return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("risk detected: %v (ticketId=%d)", err, ticket.ID)) } - access, refresh, err = services.ExchangeOauthToken(data.ClientID, data.ClientSecret, data.RedirectUri, *ticket.GrantToken) + idk, atk, rtk, err = services.ExchangeOauthToken(data.ClientID, data.ClientSecret, data.RedirectUri, *ticket.GrantToken) if err != nil { return fiber.NewError(fiber.StatusBadRequest, err.Error()) } case "grant_token": // Internal Usage - access, refresh, err = services.ExchangeToken(data.Code) + atk, rtk, err = services.ExchangeToken(data.Code) if err != nil { return fiber.NewError(fiber.StatusBadRequest, err.Error()) } @@ -153,12 +153,16 @@ func getToken(c *fiber.Ctx) error { return fiber.NewError(fiber.StatusBadRequest, "unsupported exchange token type") } - exts.SetAuthCookies(c, access, refresh) + if len(idk) == 0 { + idk = atk + } + + exts.SetAuthCookies(c, atk, rtk) return c.JSON(fiber.Map{ - "id_token": access, - "access_token": access, - "refresh_token": refresh, + "id_token": idk, + "access_token": atk, + "refresh_token": rtk, "token_type": "Bearer", "expires_in": (30 * time.Minute).Seconds(), }) diff --git a/pkg/internal/server/api/userinfo_api.go b/pkg/internal/server/api/userinfo_api.go index d1196e4..bf9471e 100644 --- a/pkg/internal/server/api/userinfo_api.go +++ b/pkg/internal/server/api/userinfo_api.go @@ -2,6 +2,7 @@ package api import ( "fmt" + "git.solsynth.dev/hydrogen/passport/pkg/internal/database" "git.solsynth.dev/hydrogen/passport/pkg/internal/models" "git.solsynth.dev/hydrogen/passport/pkg/internal/services" @@ -14,7 +15,6 @@ func getOtherUserinfo(c *fiber.Ctx) error { var account models.Account if err := database.C. Where(&models.Account{Name: alias}). - Omit("tickets", "challenges", "factors", "events", "clients", "notifications", "notify_subscribers"). Preload("Profile"). Preload("Badges"). First(&account).Error; err != nil { diff --git a/pkg/internal/services/jwt.go b/pkg/internal/services/jwt.go index 0a3ebd7..3e5597c 100644 --- a/pkg/internal/services/jwt.go +++ b/pkg/internal/services/jwt.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "git.solsynth.dev/hydrogen/passport/pkg/internal/models" "github.com/golang-jwt/jwt/v5" "github.com/spf13/viper" ) @@ -11,8 +12,16 @@ import ( type PayloadClaims struct { jwt.RegisteredClaims + // Internal Stuff + SessionID string `json:"sed"` + + // ID Token Stuff + Name string `json:"name,omitempty"` + Nick string `json:"preferred_username,omitempty"` + Email string `json:"email,omitempty"` + + // Additonal Stuff AuthorizedParties string `json:"azp,omitempty"` - SessionID string `json:"sed"` Type string `json:"typ"` } @@ -21,7 +30,7 @@ const ( JwtRefreshType = "refresh" ) -func EncodeJwt(id string, typ, sub, sed string, aud []string, exp time.Time) (string, error) { +func EncodeJwt(id string, typ, sub, sed string, aud []string, exp time.Time, idTokenUser ...models.Account) (string, error) { var azp string for _, item := range aud { if item != InternalTokenAudience { @@ -30,7 +39,7 @@ func EncodeJwt(id string, typ, sub, sed string, aud []string, exp time.Time) (st } } - tk := jwt.NewWithClaims(jwt.SigningMethodHS512, PayloadClaims{ + claims := PayloadClaims{ RegisteredClaims: jwt.RegisteredClaims{ Subject: sub, Audience: aud, @@ -43,7 +52,16 @@ func EncodeJwt(id string, typ, sub, sed string, aud []string, exp time.Time) (st AuthorizedParties: azp, SessionID: sed, Type: typ, - }) + } + + if len(idTokenUser) > 0 { + user := idTokenUser[0] + claims.Name = user.Name + claims.Nick = user.Nick + claims.Email = user.GetPrimaryEmail().Content + } + + tk := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) return tk.SignedString([]byte(viper.GetString("secret"))) } diff --git a/pkg/internal/services/ticket_token.go b/pkg/internal/services/ticket_token.go index 4cb4772..214f65d 100644 --- a/pkg/internal/services/ticket_token.go +++ b/pkg/internal/services/ticket_token.go @@ -2,96 +2,119 @@ package services import ( "fmt" + "strconv" + "time" + "git.solsynth.dev/hydrogen/passport/pkg/internal/database" "git.solsynth.dev/hydrogen/passport/pkg/internal/models" "github.com/samber/lo" "github.com/spf13/viper" - "strconv" - "time" ) -func GetToken(ticket models.AuthTicket) (string, string, error) { - var refresh, access string - if err := ticket.IsAvailable(); err != nil { - return refresh, access, err +func GetToken(ticket models.AuthTicket) (atk, rtk string, err error) { + if err = ticket.IsAvailable(); err != nil { + return } if ticket.AccessToken == nil || ticket.RefreshToken == nil { - return refresh, access, fmt.Errorf("unable to encode token, access or refresh token id missing") + err = fmt.Errorf("unable to encode token, access or refresh token id missing") + return } - accessDuration := time.Duration(viper.GetInt64("security.access_token_duration")) * time.Second - refreshDuration := time.Duration(viper.GetInt64("security.refresh_token_duration")) * time.Second + atkDeadline := time.Duration(viper.GetInt64("security.access_token_duration")) * time.Second + rtkDeadline := time.Duration(viper.GetInt64("security.refresh_token_duration")) * time.Second - var err error sub := strconv.Itoa(int(ticket.AccountID)) sed := strconv.Itoa(int(ticket.ID)) - access, err = EncodeJwt(*ticket.AccessToken, JwtAccessType, sub, sed, ticket.Audiences, time.Now().Add(accessDuration)) + atk, err = EncodeJwt(*ticket.AccessToken, JwtAccessType, sub, sed, ticket.Audiences, time.Now().Add(atkDeadline)) if err != nil { - return refresh, access, err + return } - refresh, err = EncodeJwt(*ticket.RefreshToken, JwtRefreshType, sub, sed, ticket.Audiences, time.Now().Add(refreshDuration)) + rtk, err = EncodeJwt(*ticket.RefreshToken, JwtRefreshType, sub, sed, ticket.Audiences, time.Now().Add(rtkDeadline)) if err != nil { - return refresh, access, err + return } ticket.LastGrantAt = lo.ToPtr(time.Now()) database.C.Save(&ticket) - return access, refresh, nil + return } -func ExchangeToken(token string) (string, string, error) { +func ExchangeToken(token string) (atk, rtk string, err error) { var ticket models.AuthTicket - if err := database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil { - return "", "", err + if err = database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil { + return } else if ticket.LastGrantAt != nil { - return "", "", fmt.Errorf("ticket was granted the first token, use refresh token instead") + err = fmt.Errorf("ticket was granted the first token, use refresh token instead") + return } else if len(ticket.Audiences) > 1 { - return "", "", fmt.Errorf("should use authorization code grant type") + err = fmt.Errorf("should use authorization code grant type") + return } return GetToken(ticket) } -func ExchangeOauthToken(clientId, clientSecret, redirectUri, token string) (string, string, error) { +func ExchangeOauthToken(clientId, clientSecret, redirectUri, token string) (idk, atk, rtk string, err error) { var client models.ThirdClient - if err := database.C.Where(models.ThirdClient{Alias: clientId}).First(&client).Error; err != nil { - return "", "", err + if err = database.C.Where(models.ThirdClient{Alias: clientId}).First(&client).Error; err != nil { + return } else if client.Secret != clientSecret { - return "", "", fmt.Errorf("invalid client secret") + err = fmt.Errorf("invalid client secret") + return } else if !client.IsDraft && !lo.Contains(client.Callbacks, redirectUri) { - return "", "", fmt.Errorf("invalid redirect uri") + err = fmt.Errorf("invalid redirect uri") + return } var ticket models.AuthTicket - if err := database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil { - return "", "", err + if err = database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil { + return } else if ticket.LastGrantAt != nil { - return "", "", fmt.Errorf("ticket was granted the first token, use refresh token instead") + err = fmt.Errorf("ticket was granted the first token, use refresh token instead") + return } - return GetToken(ticket) + atk, rtk, err = GetToken(ticket) + if err != nil { + return + } + + var user models.Account + if err = database.C.Where(models.Account{ + BaseModel: models.BaseModel{ID: ticket.AccountID}, + }).Preload("Contacts").First(&user).Error; err != nil { + return + } + + sub := strconv.Itoa(int(ticket.AccountID)) + sed := strconv.Itoa(int(ticket.ID)) + idk, err = EncodeJwt(*ticket.AccessToken, JwtAccessType, sub, sed, ticket.Audiences, time.Now().Add(24*time.Minute), user) + + return } -func RefreshToken(token string) (string, string, error) { +func RefreshToken(token string) (atk, rtk string, err error) { parseInt := func(str string) int { val, _ := strconv.Atoi(str) return val } var ticket models.AuthTicket - if claims, err := DecodeJwt(token); err != nil { - return "404", "403", err + var claims PayloadClaims + if claims, err = DecodeJwt(token); err != nil { + return } else if claims.Type != JwtRefreshType { - return "404", "403", fmt.Errorf("invalid token type, expected refresh token") - } else if err := database.C.Where(models.AuthTicket{ + err = fmt.Errorf("invalid token type, expected refresh token") + return + } else if err = database.C.Where(models.AuthTicket{ BaseModel: models.BaseModel{ID: uint(parseInt(claims.SessionID))}, }).First(&ticket).Error; err != nil { - return "404", "403", err + return } - if ticket, err := RegenSession(ticket); err != nil { - return "404", "403", err + if ticket, err = RegenSession(ticket); err != nil { + return } else { return GetToken(ticket) }