✨ Better id token in oidc
This commit is contained in:
		| @@ -23,18 +23,12 @@ type Account struct { | ||||
|  | ||||
| 	Profile  AccountProfile   `json:"profile,omitempty"` | ||||
| 	Contacts []AccountContact `json:"contacts,omitempty"` | ||||
| 	Statuses []Status         `json:"statuses,omitempty"` | ||||
| 	Badges   []Badge          `json:"badges,omitempty"` | ||||
|  | ||||
| 	Identities []RealmMember `json:"identities,omitempty"` | ||||
|  | ||||
| 	Tickets []AuthTicket `json:"tickets,omitempty"` | ||||
| 	Factors []AuthFactor `json:"factors,omitempty"` | ||||
|  | ||||
| 	Events      []ActionEvent `json:"events,omitempty"` | ||||
| 	MagicTokens []MagicToken  `json:"-"` | ||||
|  | ||||
| 	ThirdClients []ThirdClient `json:"clients,omitempty"` | ||||
| 	Events []ActionEvent `json:"events,omitempty"` | ||||
|  | ||||
| 	Notifications     []Notification           `json:"notifications,omitempty"` | ||||
| 	NotifySubscribers []NotificationSubscriber `json:"notify_subscribers,omitempty"` | ||||
|   | ||||
| @@ -109,17 +109,17 @@ func getToken(c *fiber.Ctx) error { | ||||
| 	} | ||||
|  | ||||
| 	var err error | ||||
| 	var access, refresh string | ||||
| 	var idk, atk, rtk string | ||||
| 	switch data.GrantType { | ||||
| 	case "refresh_token": | ||||
| 		// Refresh Token | ||||
| 		access, refresh, err = services.RefreshToken(data.RefreshToken) | ||||
| 		atk, rtk, err = services.RefreshToken(data.RefreshToken) | ||||
| 		if err != nil { | ||||
| 			return fiber.NewError(fiber.StatusBadRequest, err.Error()) | ||||
| 		} | ||||
| 	case "authorization_code": | ||||
| 		// Authorization Code Mode | ||||
| 		access, refresh, err = services.ExchangeOauthToken(data.ClientID, data.ClientSecret, data.RedirectUri, data.Code) | ||||
| 		idk, atk, rtk, err = services.ExchangeOauthToken(data.ClientID, data.ClientSecret, data.RedirectUri, data.Code) | ||||
| 		if err != nil { | ||||
| 			return fiber.NewError(fiber.StatusBadRequest, err.Error()) | ||||
| 		} | ||||
| @@ -139,13 +139,13 @@ func getToken(c *fiber.Ctx) error { | ||||
| 		} else if err := ticket.IsAvailable(); err != nil { | ||||
| 			return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("risk detected: %v (ticketId=%d)", err, ticket.ID)) | ||||
| 		} | ||||
| 		access, refresh, err = services.ExchangeOauthToken(data.ClientID, data.ClientSecret, data.RedirectUri, *ticket.GrantToken) | ||||
| 		idk, atk, rtk, err = services.ExchangeOauthToken(data.ClientID, data.ClientSecret, data.RedirectUri, *ticket.GrantToken) | ||||
| 		if err != nil { | ||||
| 			return fiber.NewError(fiber.StatusBadRequest, err.Error()) | ||||
| 		} | ||||
| 	case "grant_token": | ||||
| 		// Internal Usage | ||||
| 		access, refresh, err = services.ExchangeToken(data.Code) | ||||
| 		atk, rtk, err = services.ExchangeToken(data.Code) | ||||
| 		if err != nil { | ||||
| 			return fiber.NewError(fiber.StatusBadRequest, err.Error()) | ||||
| 		} | ||||
| @@ -153,12 +153,16 @@ func getToken(c *fiber.Ctx) error { | ||||
| 		return fiber.NewError(fiber.StatusBadRequest, "unsupported exchange token type") | ||||
| 	} | ||||
|  | ||||
| 	exts.SetAuthCookies(c, access, refresh) | ||||
| 	if len(idk) == 0 { | ||||
| 		idk = atk | ||||
| 	} | ||||
|  | ||||
| 	exts.SetAuthCookies(c, atk, rtk) | ||||
|  | ||||
| 	return c.JSON(fiber.Map{ | ||||
| 		"id_token":      access, | ||||
| 		"access_token":  access, | ||||
| 		"refresh_token": refresh, | ||||
| 		"id_token":      idk, | ||||
| 		"access_token":  atk, | ||||
| 		"refresh_token": rtk, | ||||
| 		"token_type":    "Bearer", | ||||
| 		"expires_in":    (30 * time.Minute).Seconds(), | ||||
| 	}) | ||||
|   | ||||
| @@ -2,6 +2,7 @@ package api | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
|  | ||||
| 	"git.solsynth.dev/hydrogen/passport/pkg/internal/database" | ||||
| 	"git.solsynth.dev/hydrogen/passport/pkg/internal/models" | ||||
| 	"git.solsynth.dev/hydrogen/passport/pkg/internal/services" | ||||
| @@ -14,7 +15,6 @@ func getOtherUserinfo(c *fiber.Ctx) error { | ||||
| 	var account models.Account | ||||
| 	if err := database.C. | ||||
| 		Where(&models.Account{Name: alias}). | ||||
| 		Omit("tickets", "challenges", "factors", "events", "clients", "notifications", "notify_subscribers"). | ||||
| 		Preload("Profile"). | ||||
| 		Preload("Badges"). | ||||
| 		First(&account).Error; err != nil { | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"time" | ||||
|  | ||||
| 	"git.solsynth.dev/hydrogen/passport/pkg/internal/models" | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
| 	"github.com/spf13/viper" | ||||
| ) | ||||
| @@ -11,8 +12,16 @@ import ( | ||||
| type PayloadClaims struct { | ||||
| 	jwt.RegisteredClaims | ||||
|  | ||||
| 	// Internal Stuff | ||||
| 	SessionID string `json:"sed"` | ||||
|  | ||||
| 	// ID Token Stuff | ||||
| 	Name  string `json:"name,omitempty"` | ||||
| 	Nick  string `json:"preferred_username,omitempty"` | ||||
| 	Email string `json:"email,omitempty"` | ||||
|  | ||||
| 	// Additonal Stuff | ||||
| 	AuthorizedParties string `json:"azp,omitempty"` | ||||
| 	SessionID         string `json:"sed"` | ||||
| 	Type              string `json:"typ"` | ||||
| } | ||||
|  | ||||
| @@ -21,7 +30,7 @@ const ( | ||||
| 	JwtRefreshType = "refresh" | ||||
| ) | ||||
|  | ||||
| func EncodeJwt(id string, typ, sub, sed string, aud []string, exp time.Time) (string, error) { | ||||
| func EncodeJwt(id string, typ, sub, sed string, aud []string, exp time.Time, idTokenUser ...models.Account) (string, error) { | ||||
| 	var azp string | ||||
| 	for _, item := range aud { | ||||
| 		if item != InternalTokenAudience { | ||||
| @@ -30,7 +39,7 @@ func EncodeJwt(id string, typ, sub, sed string, aud []string, exp time.Time) (st | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	tk := jwt.NewWithClaims(jwt.SigningMethodHS512, PayloadClaims{ | ||||
| 	claims := PayloadClaims{ | ||||
| 		RegisteredClaims: jwt.RegisteredClaims{ | ||||
| 			Subject:   sub, | ||||
| 			Audience:  aud, | ||||
| @@ -43,7 +52,16 @@ func EncodeJwt(id string, typ, sub, sed string, aud []string, exp time.Time) (st | ||||
| 		AuthorizedParties: azp, | ||||
| 		SessionID:         sed, | ||||
| 		Type:              typ, | ||||
| 	}) | ||||
| 	} | ||||
|  | ||||
| 	if len(idTokenUser) > 0 { | ||||
| 		user := idTokenUser[0] | ||||
| 		claims.Name = user.Name | ||||
| 		claims.Nick = user.Nick | ||||
| 		claims.Email = user.GetPrimaryEmail().Content | ||||
| 	} | ||||
|  | ||||
| 	tk := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) | ||||
|  | ||||
| 	return tk.SignedString([]byte(viper.GetString("secret"))) | ||||
| } | ||||
|   | ||||
| @@ -2,96 +2,119 @@ package services | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"git.solsynth.dev/hydrogen/passport/pkg/internal/database" | ||||
| 	"git.solsynth.dev/hydrogen/passport/pkg/internal/models" | ||||
| 	"github.com/samber/lo" | ||||
| 	"github.com/spf13/viper" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func GetToken(ticket models.AuthTicket) (string, string, error) { | ||||
| 	var refresh, access string | ||||
| 	if err := ticket.IsAvailable(); err != nil { | ||||
| 		return refresh, access, err | ||||
| func GetToken(ticket models.AuthTicket) (atk, rtk string, err error) { | ||||
| 	if err = ticket.IsAvailable(); err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	if ticket.AccessToken == nil || ticket.RefreshToken == nil { | ||||
| 		return refresh, access, fmt.Errorf("unable to encode token, access or refresh token id missing") | ||||
| 		err = fmt.Errorf("unable to encode token, access or refresh token id missing") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	accessDuration := time.Duration(viper.GetInt64("security.access_token_duration")) * time.Second | ||||
| 	refreshDuration := time.Duration(viper.GetInt64("security.refresh_token_duration")) * time.Second | ||||
| 	atkDeadline := time.Duration(viper.GetInt64("security.access_token_duration")) * time.Second | ||||
| 	rtkDeadline := time.Duration(viper.GetInt64("security.refresh_token_duration")) * time.Second | ||||
|  | ||||
| 	var err error | ||||
| 	sub := strconv.Itoa(int(ticket.AccountID)) | ||||
| 	sed := strconv.Itoa(int(ticket.ID)) | ||||
| 	access, err = EncodeJwt(*ticket.AccessToken, JwtAccessType, sub, sed, ticket.Audiences, time.Now().Add(accessDuration)) | ||||
| 	atk, err = EncodeJwt(*ticket.AccessToken, JwtAccessType, sub, sed, ticket.Audiences, time.Now().Add(atkDeadline)) | ||||
| 	if err != nil { | ||||
| 		return refresh, access, err | ||||
| 		return | ||||
| 	} | ||||
| 	refresh, err = EncodeJwt(*ticket.RefreshToken, JwtRefreshType, sub, sed, ticket.Audiences, time.Now().Add(refreshDuration)) | ||||
| 	rtk, err = EncodeJwt(*ticket.RefreshToken, JwtRefreshType, sub, sed, ticket.Audiences, time.Now().Add(rtkDeadline)) | ||||
| 	if err != nil { | ||||
| 		return refresh, access, err | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	ticket.LastGrantAt = lo.ToPtr(time.Now()) | ||||
| 	database.C.Save(&ticket) | ||||
|  | ||||
| 	return access, refresh, nil | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func ExchangeToken(token string) (string, string, error) { | ||||
| func ExchangeToken(token string) (atk, rtk string, err error) { | ||||
| 	var ticket models.AuthTicket | ||||
| 	if err := database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil { | ||||
| 		return "", "", err | ||||
| 	if err = database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil { | ||||
| 		return | ||||
| 	} else if ticket.LastGrantAt != nil { | ||||
| 		return "", "", fmt.Errorf("ticket was granted the first token, use refresh token instead") | ||||
| 		err = fmt.Errorf("ticket was granted the first token, use refresh token instead") | ||||
| 		return | ||||
| 	} else if len(ticket.Audiences) > 1 { | ||||
| 		return "", "", fmt.Errorf("should use authorization code grant type") | ||||
| 		err = fmt.Errorf("should use authorization code grant type") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	return GetToken(ticket) | ||||
| } | ||||
|  | ||||
| func ExchangeOauthToken(clientId, clientSecret, redirectUri, token string) (string, string, error) { | ||||
| func ExchangeOauthToken(clientId, clientSecret, redirectUri, token string) (idk, atk, rtk string, err error) { | ||||
| 	var client models.ThirdClient | ||||
| 	if err := database.C.Where(models.ThirdClient{Alias: clientId}).First(&client).Error; err != nil { | ||||
| 		return "", "", err | ||||
| 	if err = database.C.Where(models.ThirdClient{Alias: clientId}).First(&client).Error; err != nil { | ||||
| 		return | ||||
| 	} else if client.Secret != clientSecret { | ||||
| 		return "", "", fmt.Errorf("invalid client secret") | ||||
| 		err = fmt.Errorf("invalid client secret") | ||||
| 		return | ||||
| 	} else if !client.IsDraft && !lo.Contains(client.Callbacks, redirectUri) { | ||||
| 		return "", "", fmt.Errorf("invalid redirect uri") | ||||
| 		err = fmt.Errorf("invalid redirect uri") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var ticket models.AuthTicket | ||||
| 	if err := database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil { | ||||
| 		return "", "", err | ||||
| 	if err = database.C.Where(models.AuthTicket{GrantToken: &token}).First(&ticket).Error; err != nil { | ||||
| 		return | ||||
| 	} else if ticket.LastGrantAt != nil { | ||||
| 		return "", "", fmt.Errorf("ticket was granted the first token, use refresh token instead") | ||||
| 		err = fmt.Errorf("ticket was granted the first token, use refresh token instead") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	return GetToken(ticket) | ||||
| 	atk, rtk, err = GetToken(ticket) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	var user models.Account | ||||
| 	if err = database.C.Where(models.Account{ | ||||
| 		BaseModel: models.BaseModel{ID: ticket.AccountID}, | ||||
| 	}).Preload("Contacts").First(&user).Error; err != nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	sub := strconv.Itoa(int(ticket.AccountID)) | ||||
| 	sed := strconv.Itoa(int(ticket.ID)) | ||||
| 	idk, err = EncodeJwt(*ticket.AccessToken, JwtAccessType, sub, sed, ticket.Audiences, time.Now().Add(24*time.Minute), user) | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func RefreshToken(token string) (string, string, error) { | ||||
| func RefreshToken(token string) (atk, rtk string, err error) { | ||||
| 	parseInt := func(str string) int { | ||||
| 		val, _ := strconv.Atoi(str) | ||||
| 		return val | ||||
| 	} | ||||
|  | ||||
| 	var ticket models.AuthTicket | ||||
| 	if claims, err := DecodeJwt(token); err != nil { | ||||
| 		return "404", "403", err | ||||
| 	var claims PayloadClaims | ||||
| 	if claims, err = DecodeJwt(token); err != nil { | ||||
| 		return | ||||
| 	} else if claims.Type != JwtRefreshType { | ||||
| 		return "404", "403", fmt.Errorf("invalid token type, expected refresh token") | ||||
| 	} else if err := database.C.Where(models.AuthTicket{ | ||||
| 		err = fmt.Errorf("invalid token type, expected refresh token") | ||||
| 		return | ||||
| 	} else if err = database.C.Where(models.AuthTicket{ | ||||
| 		BaseModel: models.BaseModel{ID: uint(parseInt(claims.SessionID))}, | ||||
| 	}).First(&ticket).Error; err != nil { | ||||
| 		return "404", "403", err | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	if ticket, err := RegenSession(ticket); err != nil { | ||||
| 		return "404", "403", err | ||||
| 	if ticket, err = RegenSession(ticket); err != nil { | ||||
| 		return | ||||
| 	} else { | ||||
| 		return GetToken(ticket) | ||||
| 	} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user