diff --git a/pkg/internal/auth/http.go b/pkg/internal/auth/http.go index 81cc748..1e419ee 100644 --- a/pkg/internal/auth/http.go +++ b/pkg/internal/auth/http.go @@ -4,7 +4,7 @@ import ( "github.com/gofiber/fiber/v2" ) -func AuthContextMiddleware(c *fiber.Ctx) error { +func ContextMiddleware(c *fiber.Ctx) error { atk := tokenExtract(c) c.Locals("nex_in_token", atk) @@ -20,7 +20,7 @@ func AuthContextMiddleware(c *fiber.Ctx) error { return c.Next() } -func AuthMiddleware(c *fiber.Ctx) error { +func ValidatorMiddleware(c *fiber.Ctx) error { if c.Locals("nex_principal") == nil { err := c.Locals("nex_auth_error").(error) return fiber.NewError(fiber.StatusUnauthorized, err.Error()) diff --git a/pkg/internal/auth/userinfo.go b/pkg/internal/auth/userinfo.go index 9b6e90e..98fe8b8 100644 --- a/pkg/internal/auth/userinfo.go +++ b/pkg/internal/auth/userinfo.go @@ -33,6 +33,7 @@ func userinfoFetch(c *fiber.Ctx) error { return fiber.NewError(fiber.StatusUnauthorized, fmt.Sprintf("unable to load userinfo: %v", err)) } userinfo := sec.NewUserInfoFromProto(resp.Info.Info) + c.Locals("nex_user", userinfo) tk, err := IWriter.WriteUserInfoJwt(userinfo) if err != nil { return fiber.NewError(fiber.StatusInternalServerError, fmt.Sprintf("unable to sign userinfo: %v", err)) diff --git a/pkg/internal/http/api/forward.go b/pkg/internal/http/api/forward.go index ddde6a9..a3ddf02 100644 --- a/pkg/internal/http/api/forward.go +++ b/pkg/internal/http/api/forward.go @@ -1,6 +1,7 @@ package api import ( + "fmt" "git.solsynth.dev/hypernet/nexus/pkg/internal/directory" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/proxy" @@ -29,6 +30,12 @@ func forwardService(c *fiber.Ctx) error { url = strings.Replace(url, "/cgi/"+ogKeyword, "", 1) url = *service.HttpAddr + url + if tk, ok := c.Locals("nex_token").(string); ok { + c.Set(fiber.HeaderAuthorization, fmt.Sprintf("Bearer %s", tk)) + } else { + c.Set(fiber.HeaderAuthorization, "") + } + log.Debug(). Str("from", ogUrl). Str("to", url). diff --git a/pkg/internal/http/api/index.go b/pkg/internal/http/api/index.go index cea0a87..2e0532d 100644 --- a/pkg/internal/http/api/index.go +++ b/pkg/internal/http/api/index.go @@ -1,12 +1,15 @@ package api import ( + "git.solsynth.dev/hypernet/nexus/pkg/internal/auth" "git.solsynth.dev/hypernet/nexus/pkg/internal/http/ws" "github.com/gofiber/contrib/websocket" "github.com/gofiber/fiber/v2" ) func MapAPIs(app *fiber.App) { + app.Use(auth.ContextMiddleware) + // Some built-in public-accessible APIs wellKnown := app.Group("/.well-known").Name("Well Known") { @@ -18,12 +21,7 @@ func MapAPIs(app *fiber.App) { } // Common websocket gateway - app.Use(func(c *fiber.Ctx) error { - /*if err := exts.EnsureAuthenticated(c); err != nil { - return err - }*/ - return c.Next() - }).Get("/ws", websocket.New(ws.Listen)) + app.Use(auth.ValidatorMiddleware).Get("/ws", websocket.New(ws.Listen)) app.All("/inv/:command", invokeCommand) app.All("/cgi/:service/*", forwardService) diff --git a/pkg/internal/http/ws/connections.go b/pkg/internal/http/ws/connections.go index 673d933..b544915 100644 --- a/pkg/internal/http/ws/connections.go +++ b/pkg/internal/http/ws/connections.go @@ -2,10 +2,10 @@ package ws import ( "git.solsynth.dev/hypernet/nexus/pkg/internal/directory" + "git.solsynth.dev/hypernet/nexus/pkg/nex/sec" "math/rand" "sync" - "git.solsynth.dev/hypernet/nexus/pkg/internal/models" "github.com/gofiber/contrib/websocket" ) @@ -14,7 +14,7 @@ var ( wsConn = make(map[uint]map[uint64]*websocket.Conn) ) -func ClientRegister(user models.Account, conn *websocket.Conn) uint64 { +func ClientRegister(user sec.UserInfo, conn *websocket.Conn) uint64 { wsMutex.Lock() if wsConn[user.ID] == nil { wsConn[user.ID] = make(map[uint64]*websocket.Conn) @@ -31,7 +31,7 @@ func ClientRegister(user models.Account, conn *websocket.Conn) uint64 { return clientId } -func ClientUnregister(user models.Account, id uint64) { +func ClientUnregister(user sec.UserInfo, id uint64) { wsMutex.Lock() if wsConn[user.ID] == nil { wsConn[user.ID] = make(map[uint64]*websocket.Conn) diff --git a/pkg/internal/http/ws/ws.go b/pkg/internal/http/ws/ws.go index e8147b5..4749251 100644 --- a/pkg/internal/http/ws/ws.go +++ b/pkg/internal/http/ws/ws.go @@ -1,8 +1,8 @@ package ws import ( - "git.solsynth.dev/hypernet/nexus/pkg/internal/models" "git.solsynth.dev/hypernet/nexus/pkg/nex" + "git.solsynth.dev/hypernet/nexus/pkg/nex/sec" "github.com/gofiber/contrib/websocket" jsoniter "github.com/json-iterator/go" "github.com/rs/zerolog/log" @@ -10,7 +10,7 @@ import ( ) func Listen(c *websocket.Conn) { - user := c.Locals("user").(models.Account) + user := c.Locals("nex_user").(sec.UserInfo) // Push connection clientId := ClientRegister(user, c) diff --git a/pkg/nex/sec/adaptor.go b/pkg/nex/sec/adaptor.go new file mode 100644 index 0000000..c09e184 --- /dev/null +++ b/pkg/nex/sec/adaptor.go @@ -0,0 +1,39 @@ +package sec + +import ( + "github.com/gofiber/fiber/v2" + "strings" +) + +// ContextMiddleware provide a middleware to receive the userinfo from the nexus. +// It only works on the client-side of nexus. +// It will NOT validate the auth status if you need to validate the status of current authorization, refer to ValidatorMiddleware. +// To get the userinfo, call `c.Locals('nex_user').(sec.UserInfo)` +// Make sure you got the right public key, otherwise the auth will fail. +func ContextMiddleware(tkReader *InternalTokenReader) fiber.Handler { + return func(c *fiber.Ctx) error { + token := c.Get(fiber.HeaderAuthorization) + token = strings.TrimSpace(strings.Replace(token, "Bearer ", "", 1)) + if len(token) == 0 { + return fiber.NewError(fiber.StatusUnauthorized, "no authorization token is provided") + } + + data, err := tkReader.ReadUserInfoJwt(token) + if err != nil { + return fiber.NewError(fiber.StatusUnauthorized, err.Error()) + } + c.Locals("nex_user", data) + + return c.Next() + } +} + +// ValidatorMiddleware will ensure request is authenticated +// Make sure call this middleware after ContextMiddleware +func ValidatorMiddleware(c *fiber.Ctx) error { + if c.Locals("nex_user") == nil { + return fiber.NewError(fiber.StatusUnauthorized, "unauthorized") + } + + return c.Next() +}