Better id token in oidc

This commit is contained in:
LittleSheep 2024-07-28 20:04:22 +08:00
parent 94aed49092
commit 6ef46d984d
5 changed files with 97 additions and 58 deletions

View File

@ -23,18 +23,12 @@ type Account struct {
Profile AccountProfile `json:"profile,omitempty"` Profile AccountProfile `json:"profile,omitempty"`
Contacts []AccountContact `json:"contacts,omitempty"` Contacts []AccountContact `json:"contacts,omitempty"`
Statuses []Status `json:"statuses,omitempty"`
Badges []Badge `json:"badges,omitempty"` Badges []Badge `json:"badges,omitempty"`
Identities []RealmMember `json:"identities,omitempty"`
Tickets []AuthTicket `json:"tickets,omitempty"` Tickets []AuthTicket `json:"tickets,omitempty"`
Factors []AuthFactor `json:"factors,omitempty"` Factors []AuthFactor `json:"factors,omitempty"`
Events []ActionEvent `json:"events,omitempty"` Events []ActionEvent `json:"events,omitempty"`
MagicTokens []MagicToken `json:"-"`
ThirdClients []ThirdClient `json:"clients,omitempty"`
Notifications []Notification `json:"notifications,omitempty"` Notifications []Notification `json:"notifications,omitempty"`
NotifySubscribers []NotificationSubscriber `json:"notify_subscribers,omitempty"` NotifySubscribers []NotificationSubscriber `json:"notify_subscribers,omitempty"`

View File

@ -109,17 +109,17 @@ func getToken(c *fiber.Ctx) error {
} }
var err error var err error
var access, refresh string var idk, atk, rtk string
switch data.GrantType { switch data.GrantType {
case "refresh_token": case "refresh_token":
// Refresh Token // Refresh Token
access, refresh, err = services.RefreshToken(data.RefreshToken) atk, rtk, err = services.RefreshToken(data.RefreshToken)
if err != nil { if err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error()) return fiber.NewError(fiber.StatusBadRequest, err.Error())
} }
case "authorization_code": case "authorization_code":
// Authorization Code Mode // Authorization Code Mode
access, refresh, err = services.ExchangeOauthToken(data.ClientID, data.ClientSecret, data.RedirectUri, data.Code) idk, atk, rtk, err = services.ExchangeOauthToken(data.ClientID, data.ClientSecret, data.RedirectUri, data.Code)
if err != nil { if err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error()) return fiber.NewError(fiber.StatusBadRequest, err.Error())
} }
@ -139,13 +139,13 @@ func getToken(c *fiber.Ctx) error {
} else if err := ticket.IsAvailable(); err != nil { } else if err := ticket.IsAvailable(); err != nil {
return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("risk detected: %v (ticketId=%d)", err, ticket.ID)) return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("risk detected: %v (ticketId=%d)", err, ticket.ID))
} }
access, refresh, err = services.ExchangeOauthToken(data.ClientID, data.ClientSecret, data.RedirectUri, *ticket.GrantToken) idk, atk, rtk, err = services.ExchangeOauthToken(data.ClientID, data.ClientSecret, data.RedirectUri, *ticket.GrantToken)
if err != nil { if err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error()) return fiber.NewError(fiber.StatusBadRequest, err.Error())
} }
case "grant_token": case "grant_token":
// Internal Usage // Internal Usage
access, refresh, err = services.ExchangeToken(data.Code) atk, rtk, err = services.ExchangeToken(data.Code)
if err != nil { if err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error()) return fiber.NewError(fiber.StatusBadRequest, err.Error())
} }
@ -153,12 +153,16 @@ func getToken(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusBadRequest, "unsupported exchange token type") return fiber.NewError(fiber.StatusBadRequest, "unsupported exchange token type")
} }
exts.SetAuthCookies(c, access, refresh) if len(idk) == 0 {
idk = atk
}
exts.SetAuthCookies(c, atk, rtk)
return c.JSON(fiber.Map{ return c.JSON(fiber.Map{
"id_token": access, "id_token": idk,
"access_token": access, "access_token": atk,
"refresh_token": refresh, "refresh_token": rtk,
"token_type": "Bearer", "token_type": "Bearer",
"expires_in": (30 * time.Minute).Seconds(), "expires_in": (30 * time.Minute).Seconds(),
}) })

View File

@ -2,6 +2,7 @@ package api
import ( import (
"fmt" "fmt"
"git.solsynth.dev/hydrogen/passport/pkg/internal/database" "git.solsynth.dev/hydrogen/passport/pkg/internal/database"
"git.solsynth.dev/hydrogen/passport/pkg/internal/models" "git.solsynth.dev/hydrogen/passport/pkg/internal/models"
"git.solsynth.dev/hydrogen/passport/pkg/internal/services" "git.solsynth.dev/hydrogen/passport/pkg/internal/services"
@ -14,7 +15,6 @@ func getOtherUserinfo(c *fiber.Ctx) error {
var account models.Account var account models.Account
if err := database.C. if err := database.C.
Where(&models.Account{Name: alias}). Where(&models.Account{Name: alias}).
Omit("tickets", "challenges", "factors", "events", "clients", "notifications", "notify_subscribers").
Preload("Profile"). Preload("Profile").
Preload("Badges"). Preload("Badges").
First(&account).Error; err != nil { First(&account).Error; err != nil {

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"time" "time"
"git.solsynth.dev/hydrogen/passport/pkg/internal/models"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
@ -11,8 +12,16 @@ import (
type PayloadClaims struct { type PayloadClaims struct {
jwt.RegisteredClaims jwt.RegisteredClaims
AuthorizedParties string `json:"azp,omitempty"` // Internal Stuff
SessionID string `json:"sed"` SessionID string `json:"sed"`
// ID Token Stuff
Name string `json:"name,omitempty"`
Nick string `json:"preferred_username,omitempty"`
Email string `json:"email,omitempty"`
// Additonal Stuff
AuthorizedParties string `json:"azp,omitempty"`
Type string `json:"typ"` Type string `json:"typ"`
} }
@ -21,7 +30,7 @@ const (
JwtRefreshType = "refresh" JwtRefreshType = "refresh"
) )
func EncodeJwt(id string, typ, sub, sed string, aud []string, exp time.Time) (string, error) { func EncodeJwt(id string, typ, sub, sed string, aud []string, exp time.Time, idTokenUser ...models.Account) (string, error) {
var azp string var azp string
for _, item := range aud { for _, item := range aud {
if item != InternalTokenAudience { if item != InternalTokenAudience {
@ -30,7 +39,7 @@ func EncodeJwt(id string, typ, sub, sed string, aud []string, exp time.Time) (st
} }
} }
tk := jwt.NewWithClaims(jwt.SigningMethodHS512, PayloadClaims{ claims := PayloadClaims{
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
Subject: sub, Subject: sub,
Audience: aud, Audience: aud,
@ -43,7 +52,16 @@ func EncodeJwt(id string, typ, sub, sed string, aud []string, exp time.Time) (st
AuthorizedParties: azp, AuthorizedParties: azp,
SessionID: sed, SessionID: sed,
Type: typ, Type: typ,
}) }
if len(idTokenUser) > 0 {
user := idTokenUser[0]
claims.Name = user.Name
claims.Nick = user.Nick
claims.Email = user.GetPrimaryEmail().Content
}
tk := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
return tk.SignedString([]byte(viper.GetString("secret"))) return tk.SignedString([]byte(viper.GetString("secret")))
} }

View File

@ -2,96 +2,119 @@ package services
import ( import (
"fmt" "fmt"
"strconv"
"time"
"git.solsynth.dev/hydrogen/passport/pkg/internal/database" "git.solsynth.dev/hydrogen/passport/pkg/internal/database"
"git.solsynth.dev/hydrogen/passport/pkg/internal/models" "git.solsynth.dev/hydrogen/passport/pkg/internal/models"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/spf13/viper" "github.com/spf13/viper"
"strconv"
"time"
) )
func GetToken(ticket models.AuthTicket) (string, string, error) { func GetToken(ticket models.AuthTicket) (atk, rtk string, err error) {
var refresh, access string if err = ticket.IsAvailable(); err != nil {
if err := ticket.IsAvailable(); err != nil { return
return refresh, access, err
} }
if ticket.AccessToken == nil || ticket.RefreshToken == nil { if ticket.AccessToken == nil || ticket.RefreshToken == nil {
return refresh, access, fmt.Errorf("unable to encode token, access or refresh token id missing") err = fmt.Errorf("unable to encode token, access or refresh token id missing")
return
} }
accessDuration := time.Duration(viper.GetInt64("security.access_token_duration")) * time.Second atkDeadline := time.Duration(viper.GetInt64("security.access_token_duration")) * time.Second
refreshDuration := time.Duration(viper.GetInt64("security.refresh_token_duration")) * time.Second rtkDeadline := time.Duration(viper.GetInt64("security.refresh_token_duration")) * time.Second
var err error
sub := strconv.Itoa(int(ticket.AccountID)) sub := strconv.Itoa(int(ticket.AccountID))
sed := strconv.Itoa(int(ticket.ID)) sed := strconv.Itoa(int(ticket.ID))
access, err = EncodeJwt(*ticket.AccessToken, JwtAccessType, sub, sed, ticket.Audiences, time.Now().Add(accessDuration)) atk, err = EncodeJwt(*ticket.AccessToken, JwtAccessType, sub, sed, ticket.Audiences, time.Now().Add(atkDeadline))
if err != nil { if err != nil {
return refresh, access, err return
} }
refresh, err = EncodeJwt(*ticket.RefreshToken, JwtRefreshType, sub, sed, ticket.Audiences, time.Now().Add(refreshDuration)) rtk, err = EncodeJwt(*ticket.RefreshToken, JwtRefreshType, sub, sed, ticket.Audiences, time.Now().Add(rtkDeadline))
if err != nil { if err != nil {
return refresh, access, err return
} }
ticket.LastGrantAt = lo.ToPtr(time.Now()) ticket.LastGrantAt = lo.ToPtr(time.Now())
database.C.Save(&ticket) database.C.Save(&ticket)
return access, refresh, nil return
} }
func ExchangeToken(token string) (string, string, error) { func ExchangeToken(token string) (atk, rtk string, err error) {
var ticket models.AuthTicket var ticket models.AuthTicket
if err := database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil { if err = database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil {
return "", "", err return
} else if ticket.LastGrantAt != nil { } else if ticket.LastGrantAt != nil {
return "", "", fmt.Errorf("ticket was granted the first token, use refresh token instead") err = fmt.Errorf("ticket was granted the first token, use refresh token instead")
return
} else if len(ticket.Audiences) > 1 { } else if len(ticket.Audiences) > 1 {
return "", "", fmt.Errorf("should use authorization code grant type") err = fmt.Errorf("should use authorization code grant type")
return
} }
return GetToken(ticket) return GetToken(ticket)
} }
func ExchangeOauthToken(clientId, clientSecret, redirectUri, token string) (string, string, error) { func ExchangeOauthToken(clientId, clientSecret, redirectUri, token string) (idk, atk, rtk string, err error) {
var client models.ThirdClient var client models.ThirdClient
if err := database.C.Where(models.ThirdClient{Alias: clientId}).First(&client).Error; err != nil { if err = database.C.Where(models.ThirdClient{Alias: clientId}).First(&client).Error; err != nil {
return "", "", err return
} else if client.Secret != clientSecret { } else if client.Secret != clientSecret {
return "", "", fmt.Errorf("invalid client secret") err = fmt.Errorf("invalid client secret")
return
} else if !client.IsDraft && !lo.Contains(client.Callbacks, redirectUri) { } else if !client.IsDraft && !lo.Contains(client.Callbacks, redirectUri) {
return "", "", fmt.Errorf("invalid redirect uri") err = fmt.Errorf("invalid redirect uri")
return
} }
var ticket models.AuthTicket var ticket models.AuthTicket
if err := database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil { if err = database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil {
return "", "", err return
} else if ticket.LastGrantAt != nil { } else if ticket.LastGrantAt != nil {
return "", "", fmt.Errorf("ticket was granted the first token, use refresh token instead") err = fmt.Errorf("ticket was granted the first token, use refresh token instead")
return
} }
return GetToken(ticket) atk, rtk, err = GetToken(ticket)
if err != nil {
return
}
var user models.Account
if err = database.C.Where(models.Account{
BaseModel: models.BaseModel{ID: ticket.AccountID},
}).Preload("Contacts").First(&user).Error; err != nil {
return
}
sub := strconv.Itoa(int(ticket.AccountID))
sed := strconv.Itoa(int(ticket.ID))
idk, err = EncodeJwt(*ticket.AccessToken, JwtAccessType, sub, sed, ticket.Audiences, time.Now().Add(24*time.Minute), user)
return
} }
func RefreshToken(token string) (string, string, error) { func RefreshToken(token string) (atk, rtk string, err error) {
parseInt := func(str string) int { parseInt := func(str string) int {
val, _ := strconv.Atoi(str) val, _ := strconv.Atoi(str)
return val return val
} }
var ticket models.AuthTicket var ticket models.AuthTicket
if claims, err := DecodeJwt(token); err != nil { var claims PayloadClaims
return "404", "403", err if claims, err = DecodeJwt(token); err != nil {
return
} else if claims.Type != JwtRefreshType { } else if claims.Type != JwtRefreshType {
return "404", "403", fmt.Errorf("invalid token type, expected refresh token") err = fmt.Errorf("invalid token type, expected refresh token")
} else if err := database.C.Where(models.AuthTicket{ return
} else if err = database.C.Where(models.AuthTicket{
BaseModel: models.BaseModel{ID: uint(parseInt(claims.SessionID))}, BaseModel: models.BaseModel{ID: uint(parseInt(claims.SessionID))},
}).First(&ticket).Error; err != nil { }).First(&ticket).Error; err != nil {
return "404", "403", err return
} }
if ticket, err := RegenSession(ticket); err != nil { if ticket, err = RegenSession(ticket); err != nil {
return "404", "403", err return
} else { } else {
return GetToken(ticket) return GetToken(ticket)
} }