OAuth2 Support

This commit is contained in:
LittleSheep 2024-01-30 15:57:49 +08:00
parent f78ccd8d9d
commit 0497e9717b
15 changed files with 442 additions and 47 deletions

View File

@ -14,6 +14,7 @@ func RunMigration(source *gorm.DB) error {
&models.AuthSession{},
&models.AuthChallenge{},
&models.MagicToken{},
&models.ThirdClient{},
); err != nil {
return err
}

View File

@ -17,17 +17,18 @@ const (
type Account struct {
BaseModel
Name string `json:"name" gorm:"uniqueIndex"`
Nick string `json:"nick"`
State AccountState `json:"state"`
Profile AccountProfile `json:"profile"`
Session []AuthSession `json:"sessions"`
Challenges []AuthChallenge `json:"challenges"`
Factors []AuthFactor `json:"factors"`
Contacts []AccountContact `json:"contacts"`
MagicTokens []MagicToken `json:"-" gorm:"foreignKey:AssignTo"`
ConfirmedAt *time.Time `json:"confirmed_at"`
Permissions datatypes.JSONType[[]string] `json:"permissions"`
Name string `json:"name" gorm:"uniqueIndex"`
Nick string `json:"nick"`
State AccountState `json:"state"`
Profile AccountProfile `json:"profile"`
Sessions []AuthSession `json:"sessions"`
Challenges []AuthChallenge `json:"challenges"`
Factors []AuthFactor `json:"factors"`
Contacts []AccountContact `json:"contacts"`
MagicTokens []MagicToken `json:"-" gorm:"foreignKey:AssignTo"`
ThirdClients []ThirdClient `json:"clients"`
ConfirmedAt *time.Time `json:"confirmed_at"`
Permissions datatypes.JSONType[[]string] `json:"permissions"`
}
func (v Account) GetPrimaryEmail() AccountContact {

View File

@ -27,6 +27,7 @@ type AuthSession struct {
BaseModel
Claims datatypes.JSONSlice[string] `json:"claims"`
Audiences datatypes.JSONSlice[string] `json:"audiences"`
Challenge AuthChallenge `json:"challenge" gorm:"foreignKey:SessionID"`
GrantToken string `json:"grant_token"`
AccessToken string `json:"access_token"`
@ -34,6 +35,7 @@ type AuthSession struct {
ExpiredAt *time.Time `json:"expired_at"`
AvailableAt *time.Time `json:"available_at"`
LastGrantAt *time.Time `json:"last_grant_at"`
ClientID *uint `json:"client_id"`
AccountID uint `json:"account_id"`
}
@ -59,6 +61,7 @@ const (
type AuthChallenge struct {
BaseModel
Location string `json:"location"`
IpAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
RiskLevel int `json:"risk_level"`

View File

@ -1,4 +1,17 @@
package models
type OauthClients struct {
import "gorm.io/datatypes"
type ThirdClient struct {
BaseModel
Alias string `json:"alias" gorm:"uniqueIndex"`
Name string `json:"name"`
Description string `json:"description"`
Secret string `json:"secret"`
Urls datatypes.JSONSlice[string] `json:"urls"`
Callbacks datatypes.JSONSlice[string] `json:"callbacks"`
Sessions []AuthSession `json:"sessions" gorm:"foreignKey:ClientID"`
IsDraft bool `json:"is_draft"`
AccountID *uint `json:"account_id"`
}

View File

@ -10,31 +10,36 @@ import (
"gorm.io/datatypes"
)
func NewChallenge(account models.Account, factors []models.AuthFactor, ip, ua string) (models.AuthChallenge, error) {
func CalcRisk(user models.Account, ip, ua string) int {
risk := 3
var secureFactor int64
if err := database.C.Where(models.AuthChallenge{
AccountID: user.ID,
IpAddress: ip,
}).Model(models.AuthChallenge{}).Count(&secureFactor).Error; err == nil {
if secureFactor >= 3 {
risk -= 2
} else if secureFactor >= 1 {
risk -= 1
}
}
return risk
}
func NewChallenge(user models.Account, factors []models.AuthFactor, ip, ua string) (models.AuthChallenge, error) {
var challenge models.AuthChallenge
// Pickup any challenge if possible
if err := database.C.Where(models.AuthChallenge{
AccountID: account.ID,
AccountID: user.ID,
}).Where("state = ?", models.ActiveChallengeState).First(&challenge).Error; err == nil {
return challenge, nil
}
// Reduce the risk level
var secureFactor int64
if err := database.C.Where(models.AuthChallenge{
AccountID: account.ID,
IpAddress: ip,
}).Model(models.AuthChallenge{}).Count(&secureFactor).Error; err != nil {
return challenge, err
}
if secureFactor >= 3 {
risk -= 2
} else if secureFactor >= 1 {
risk -= 1
}
// Calculate the risk level
risk := CalcRisk(user, ip, ua)
// Thinking of the requirements factors
// Clamp risk in the exists requirements factor count
requirements := lo.Clamp(risk, 1, len(factors))
challenge = models.AuthChallenge{
@ -45,7 +50,7 @@ func NewChallenge(account models.Account, factors []models.AuthFactor, ip, ua st
BlacklistFactors: datatypes.NewJSONType([]uint{}),
State: models.ActiveChallengeState,
ExpiredAt: time.Now().Add(2 * time.Hour),
AccountID: account.ID,
AccountID: user.ID,
}
err := database.C.Save(&challenge).Error

View File

@ -11,7 +11,7 @@ import (
"github.com/samber/lo"
)
func GrantSession(challenge models.AuthChallenge, claims []string, expired *time.Time, available *time.Time) (models.AuthSession, error) {
func GrantSession(challenge models.AuthChallenge, claims, audiences []string, expired, available *time.Time) (models.AuthSession, error) {
var session models.AuthSession
if err := challenge.IsAvailable(); err != nil {
return session, err
@ -24,6 +24,7 @@ func GrantSession(challenge models.AuthChallenge, claims []string, expired *time
session = models.AuthSession{
Claims: claims,
Audiences: audiences,
Challenge: challenge,
GrantToken: uuid.NewString(),
AccessToken: uuid.NewString(),
@ -42,7 +43,42 @@ func GrantSession(challenge models.AuthChallenge, claims []string, expired *time
return session, nil
}
func GetToken(session models.AuthSession, aud ...string) (string, string, error) {
func GrantOauthSession(user models.Account, client models.ThirdClient, claims, audiences []string, expired, available *time.Time, ip, ua string) (models.AuthSession, error) {
session := models.AuthSession{
Claims: claims,
Audiences: audiences,
Challenge: models.AuthChallenge{
IpAddress: ip,
UserAgent: ua,
RiskLevel: CalcRisk(user, ip, ua),
State: models.FinishChallengeState,
AccountID: user.ID,
},
GrantToken: uuid.NewString(),
AccessToken: uuid.NewString(),
RefreshToken: uuid.NewString(),
ExpiredAt: expired,
AvailableAt: available,
ClientID: &client.ID,
AccountID: user.ID,
}
if err := database.C.Save(&session).Error; err != nil {
return session, err
}
return session, nil
}
func RegenSession(session models.AuthSession) (models.AuthSession, error) {
session.GrantToken = uuid.NewString()
session.AccessToken = uuid.NewString()
session.RefreshToken = uuid.NewString()
err := database.C.Save(&session).Error
return session, err
}
func GetToken(session models.AuthSession) (string, string, error) {
var refresh, access string
if err := session.IsAvailable(); err != nil {
return refresh, access, err
@ -51,11 +87,11 @@ func GetToken(session models.AuthSession, aud ...string) (string, string, error)
var err error
sub := strconv.Itoa(int(session.ID))
access, err = EncodeJwt(session.AccessToken, nil, JwtAccessType, sub, aud, time.Now().Add(30*time.Minute))
access, err = EncodeJwt(session.AccessToken, nil, JwtAccessType, sub, session.Audiences, time.Now().Add(30*time.Minute))
if err != nil {
return refresh, access, err
}
refresh, err = EncodeJwt(session.RefreshToken, nil, JwtRefreshType, sub, aud, time.Now().Add(30*24*time.Hour))
refresh, err = EncodeJwt(session.RefreshToken, nil, JwtRefreshType, sub, session.Audiences, time.Now().Add(30*24*time.Hour))
if err != nil {
return refresh, access, err
}
@ -66,7 +102,29 @@ func GetToken(session models.AuthSession, aud ...string) (string, string, error)
return access, refresh, nil
}
func ExchangeToken(token string, aud ...string) (string, string, error) {
func ExchangeToken(token string) (string, string, error) {
var session models.AuthSession
if err := database.C.Where(models.AuthSession{GrantToken: token}).First(&session).Error; err != nil {
return "404", "403", err
} else if session.LastGrantAt != nil {
return "404", "403", fmt.Errorf("session was granted the first token, use refresh token instead")
} else if len(session.Audiences) > 1 {
return "404", "403", fmt.Errorf("should use authorization code grant type")
}
return GetToken(session)
}
func ExchangeOauthToken(clientId, clientSecret, redirectUri, token string) (string, string, error) {
var client models.ThirdClient
if err := database.C.Where(models.ThirdClient{Alias: clientId}).First(&client).Error; err != nil {
return "404", "403", err
} else if client.Secret != clientSecret {
return "404", "403", fmt.Errorf("invalid client secret")
} else if !client.IsDraft && !lo.Contains(client.Callbacks, redirectUri) {
return "404", "403", fmt.Errorf("invalid redirect uri")
}
var session models.AuthSession
if err := database.C.Where(models.AuthSession{GrantToken: token}).First(&session).Error; err != nil {
return "404", "403", err
@ -74,10 +132,10 @@ func ExchangeToken(token string, aud ...string) (string, string, error) {
return "404", "403", fmt.Errorf("session was granted the first token, use refresh token instead")
}
return GetToken(session, aud...)
return GetToken(session)
}
func RefreshToken(token string, aud ...string) (string, string, error) {
func RefreshToken(token string) (string, string, error) {
parseInt := func(str string) int {
val, _ := strconv.Atoi(str)
return val
@ -94,5 +152,5 @@ func RefreshToken(token string, aud ...string) (string, string, error) {
return "404", "403", err
}
return GetToken(session, aud...)
return GetToken(session)
}

View File

@ -7,16 +7,20 @@ import (
"fmt"
"github.com/gofiber/fiber/v2"
"github.com/spf13/viper"
"gorm.io/gorm/clause"
)
func getPrincipal(c *fiber.Ctx) error {
user := c.Locals("principal").(models.Account)
var data models.Account
if err := database.C.Where(&models.Account{
BaseModel: models.BaseModel{ID: user.ID},
}).Preload(clause.Associations).First(&data).Error; err != nil {
if err := database.C.
Where(&models.Account{BaseModel: models.BaseModel{ID: user.ID}}).
Preload("Profile").
Preload("Contacts").
Preload("Factors").
Preload("Sessions").
Preload("Challenges").
First(&data).Error; err != nil {
return fiber.NewError(fiber.StatusInternalServerError, err.Error())
}

View File

@ -68,7 +68,7 @@ func doChallenge(c *fiber.Ctx) error {
if err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
} else if challenge.Progress >= challenge.Requirements {
session, err := security.GrantSession(challenge, []string{"*"}, nil, lo.ToPtr(time.Now()))
session, err := security.GrantSession(challenge, []string{"*"}, []string{"Hydrogen.Passport"}, nil, lo.ToPtr(time.Now()))
if err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
}
@ -89,8 +89,11 @@ func doChallenge(c *fiber.Ctx) error {
func exchangeToken(c *fiber.Ctx) error {
var data struct {
Code string `json:"code"`
GrantType string `json:"grant_type"`
Code string `json:"code" form:"code"`
ClientID string `json:"client_id" form:"client_id"`
ClientSecret string `json:"client_secret" form:"client_secret"`
RedirectUri string `json:"redirect_uri" form:"redirect_uri"`
GrantType string `json:"grant_type" form:"grant_type"`
}
if err := BindAndValidate(c, &data); err != nil {
@ -99,6 +102,18 @@ func exchangeToken(c *fiber.Ctx) error {
switch data.GrantType {
case "authorization_code":
// Authorization Code Mode
access, refresh, err := security.ExchangeOauthToken(data.ClientID, data.ClientSecret, data.RedirectUri, data.Code)
if err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
}
return c.JSON(fiber.Map{
"access_token": access,
"refresh_token": refresh,
})
case "grant_token":
// Internal Usage
access, refresh, err := security.ExchangeToken(data.Code)
if err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
@ -109,6 +124,7 @@ func exchangeToken(c *fiber.Ctx) error {
"refresh_token": refresh,
})
case "refresh_token":
// Refresh Token
access, refresh, err := security.RefreshToken(data.Code)
if err != nil {
return fiber.NewError(fiber.StatusBadRequest, err.Error())
@ -119,6 +135,6 @@ func exchangeToken(c *fiber.Ctx) error {
"refresh_token": refresh,
})
default:
return fiber.NewError(fiber.StatusBadRequest, "Unsupported exchange token type.")
return fiber.NewError(fiber.StatusBadRequest, "unsupported exchange token type")
}
}

119
pkg/server/oauth_api.go Normal file
View File

@ -0,0 +1,119 @@
package server
import (
"code.smartsheep.studio/hydrogen/passport/pkg/database"
"code.smartsheep.studio/hydrogen/passport/pkg/models"
"code.smartsheep.studio/hydrogen/passport/pkg/security"
"github.com/gofiber/fiber/v2"
"github.com/samber/lo"
"strings"
"time"
)
func preConnect(c *fiber.Ctx) error {
id := c.Query("client_id")
redirect := c.Query("redirect_uri")
var client models.ThirdClient
if err := database.C.Where(&models.ThirdClient{Alias: id}).First(&client).Error; err != nil {
return fiber.NewError(fiber.StatusNotFound, err.Error())
} else if !client.IsDraft && !lo.Contains(client.Callbacks, strings.Split(redirect, "?")[0]) {
return fiber.NewError(fiber.StatusBadRequest, "invalid request url")
}
user := c.Locals("principal").(models.Account)
var session models.AuthSession
if err := database.C.Where(&models.AuthSession{
AccountID: user.ID,
ClientID: &client.ID,
}).First(&session).Error; err == nil {
if session.ExpiredAt != nil && session.ExpiredAt.Unix() < time.Now().Unix() {
return c.JSON(fiber.Map{
"client": client,
"session": nil,
})
} else {
session, err = security.RegenSession(session)
}
return c.JSON(fiber.Map{
"client": client,
"session": session,
})
}
return c.JSON(fiber.Map{
"client": client,
"session": nil,
})
}
func doConnect(c *fiber.Ctx) error {
user := c.Locals("principal").(models.Account)
id := c.Query("client_id")
response := c.Query("response_type")
redirect := c.Query("redirect_uri")
scope := c.Query("scope")
if len(scope) <= 0 {
return fiber.NewError(fiber.StatusBadRequest, "invalid request params")
}
var client models.ThirdClient
if err := database.C.Where(&models.ThirdClient{Alias: id}).First(&client).Error; err != nil {
return fiber.NewError(fiber.StatusNotFound, err.Error())
}
switch response {
case "code":
// OAuth Authorization Mode
expired := time.Now().Add(7 * 24 * time.Hour)
session, err := security.GrantOauthSession(
user,
client,
strings.Split(scope, " "),
[]string{"Hydrogen.Passport", client.Alias},
&expired,
lo.ToPtr(time.Now()),
c.IP(),
c.Get(fiber.HeaderUserAgent),
)
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, err.Error())
} else {
return c.JSON(fiber.Map{
"session": session,
"redirect_uri": redirect,
})
}
case "token":
// OAuth Implicit Mode
expired := time.Now().Add(7 * 24 * time.Hour)
session, err := security.GrantOauthSession(
user,
client,
strings.Split(scope, " "),
[]string{"Hydrogen.Passport", client.Alias},
&expired,
lo.ToPtr(time.Now()),
c.IP(),
c.Get(fiber.HeaderUserAgent),
)
if err != nil {
return fiber.NewError(fiber.StatusInternalServerError, err.Error())
} else if access, refresh, err := security.GetToken(session); err != nil {
return fiber.NewError(fiber.StatusInternalServerError, err.Error())
} else {
return c.JSON(fiber.Map{
"access_token": access,
"refresh_token": refresh,
"redirect_uri": redirect,
"session": session,
})
}
default:
return fiber.NewError(fiber.StatusBadRequest, "unsupported response type")
}
}

View File

@ -31,6 +31,9 @@ func NewServer() {
api.Post("/auth", doChallenge)
api.Post("/auth/token", exchangeToken)
api.Post("/auth/factors/:factorId", requestFactorToken)
api.Get("/auth/oauth/connect", auth, preConnect)
api.Post("/auth/oauth/connect", auth, doConnect)
}
}

View File

@ -21,6 +21,8 @@ render(() => (
<Route path="/" component={lazy(() => import("./pages/dashboard.tsx"))} />
<Route path="/auth/login" component={lazy(() => import("./pages/auth/login.tsx"))} />
<Route path="/auth/register" component={lazy(() => import("./pages/auth/register.tsx"))} />
<Route path="/auth/oauth/connect" component={lazy(() => import("./pages/auth/connect.tsx"))} />
<Route path="/auth/oauth/callback" component={lazy(() => import("./pages/auth/callback.tsx"))} />
<Route path="/users/me/confirm" component={lazy(() => import("./pages/users/confirm.tsx"))} />
</Router>
</UserinfoProvider>

View File

@ -0,0 +1,30 @@
import { useSearchParams } from "@solidjs/router";
export default function DefaultCallbackPage() {
const [searchParams] = useSearchParams();
return (
<div class="w-full h-full flex justify-center items-center">
<div class="card w-[480px] max-w-screen shadow-xl">
<div class="card-body">
<div id="header" class="text-center mb-5">
{/* Just Kidding */}
<h1 class="text-xl font-bold">Default Callback</h1>
<p>
If you see this page, it means some genius developer forgot to set the redirect address, so you visited
this default callback address.
General Douglas MacArthur, a five-star general in the United States, commented on this: "If I let my
soldiers use default callbacks, they'd rather die."
The large documentary film "Callback Legend" is currently in theaters.
</p>
</div>
<div class="text-center">
<p>Authorization Code</p>
<code>{searchParams["code"]}</code>
</div>
</div>
</div>
</div>
);
}

View File

@ -0,0 +1,140 @@
import { createSignal, Show } from "solid-js";
import { useLocation, useSearchParams } from "@solidjs/router";
import { getAtk, useUserinfo } from "../../stores/userinfo.tsx";
export default function OauthConnectPage() {
const [title, setTitle] = createSignal("Connect Third-party");
const [subtitle, setSubtitle] = createSignal("Via your Goatpass account");
const [error, setError] = createSignal<string | null>(null);
const [status, setStatus] = createSignal("Handshaking...");
const [loading, setLoading] = createSignal(true);
const [client, setClient] = createSignal<any>(null);
const [searchParams] = useSearchParams();
const userinfo = useUserinfo();
const location = useLocation();
async function preConnect() {
const res = await fetch(`/api/auth/oauth/connect${location.search}`, {
headers: { "Authorization": `Bearer ${getAtk()}` }
});
if (res.status !== 200) {
setError(await res.text());
} else {
const data = await res.json();
if (data["session"]) {
setStatus("Redirecting...");
redirect(data["session"]);
} else {
setTitle(`Connect ${data["client"].name}`);
setSubtitle(`Via ${userinfo?.displayName}`);
setClient(data["client"]);
setLoading(false);
}
}
}
function decline() {
if (window.history.length > 0) {
window.history.back();
} else {
window.close();
}
}
async function approve() {
setLoading(true);
setStatus("Approving...");
const res = await fetch("/api/auth/oauth/connect?" + new URLSearchParams({
client_id: searchParams["client_id"] as string,
redirect_uri: encodeURIComponent(searchParams["redirect_uri"] as string),
response_type: "code",
scope: searchParams["scope"] as string
}), {
method: "POST",
headers: { "Authorization": `Bearer ${getAtk()}` }
});
if (res.status !== 200) {
setError(await res.text());
setLoading(false);
} else {
const data = await res.json();
setStatus("Redirecting...");
setTimeout(() => redirect(data["session"]), 1850);
}
}
function redirect(session: any) {
const url = `${searchParams["redirect_uri"]}?code=${session["grant_token"]}&state=${searchParams["state"]}`;
window.open(url, "_self");
}
preConnect();
return (
<div class="w-full h-full flex justify-center items-center">
<div class="card w-[480px] max-w-screen shadow-xl">
<div class="card-body">
<div id="header" class="text-center mb-5">
<h1 class="text-xl font-bold">{title()}</h1>
<p>{subtitle()}</p>
</div>
<Show when={error()}>
<div id="alerts" class="mt-1">
<div role="alert" class="alert alert-error">
<svg xmlns="http://www.w3.org/2000/svg" class="stroke-current shrink-0 h-6 w-6" fill="none"
viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2"
d="M10 14l2-2m0 0l2-2m-2 2l-2-2m2 2l2 2m7-2a9 9 0 11-18 0 9 9 0 0118 0z" />
</svg>
<span class="capitalize">{error()}</span>
</div>
</div>
</Show>
<Show when={loading()}>
<div class="py-16 text-center">
<div class="text-center">
<div>
<span class="loading loading-lg loading-bars"></span>
</div>
<span>{status()}</span>
</div>
</div>
</Show>
<Show when={!loading()}>
<div class="mb-3">
<h2 class="font-bold">About who you connecting to</h2>
<p>{client().description}</p>
</div>
<div class="mb-3">
<h2 class="font-bold">Make sure you trust them</h2>
<p>You may share your personal information after connect them. Learn about their privacy policy and user
agreement to keep your personal information in safe.</p>
</div>
<div class="mb-5">
<h2 class="font-bold">After approve this request</h2>
<p>
You will be redirect to{" "}
<span class="link link-primary cursor-not-allowed">{searchParams["redirect_uri"]}</span>
</p>
</div>
<div class="grid grid-cols-1 md:grid-cols-2">
<button class="btn btn-accent" onClick={() => decline()}>Decline</button>
<button class="btn btn-primary" onClick={() => approve()}>Approve</button>
</div>
</Show>
</div>
</div>
</div>
);
}

View File

@ -108,7 +108,7 @@ export default function LoginPage() {
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
code: tk,
grant_type: "authorization_code"
grant_type: "grant_token"
})
});
if (res.status !== 200) {

View File

@ -5,7 +5,7 @@ export default function DashboardPage() {
const userinfo = useUserinfo();
return (
<div class="container mx-auto pt-12">
<div class="max-w-[720px] mx-auto px-5 pt-12">
<h1 class="text-2xl font-bold">Welcome, {userinfo?.displayName}</h1>
<p>What's a nice day!</p>