From 799bfcc263204d1517757b91e3cbc9a1299e4a45 Mon Sep 17 00:00:00 2001 From: LittleSheep Date: Sun, 20 Oct 2024 21:22:53 +0800 Subject: [PATCH] :sparkles: Validation --- pkg/http/server.go | 4 +-- pkg/nex/command_context.go | 24 ++++++++++++++++ pkg/nex/const.go | 8 ------ pkg/nex/cruda/command.go | 54 ++++++++++++++++++++++++----------- pkg/nex/cruda/command_test.go | 2 +- 5 files changed, 65 insertions(+), 27 deletions(-) delete mode 100644 pkg/nex/const.go diff --git a/pkg/http/server.go b/pkg/http/server.go index 5244544..d51dcf8 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -19,8 +19,8 @@ func NewServer() *HTTPApp { app := fiber.New(fiber.Config{ DisableStartupMessage: true, EnableIPValidation: true, - ServerHeader: "Hydrogen.Nexus", - AppName: "Hydrogen.Nexus", + ServerHeader: "Hypernet.Nexus", + AppName: "Hypernet.Nexus", ProxyHeader: fiber.HeaderXForwardedFor, JSONEncoder: json.Marshal, JSONDecoder: json.Unmarshal, diff --git a/pkg/nex/command_context.go b/pkg/nex/command_context.go index 8f4803c..903f21a 100644 --- a/pkg/nex/command_context.go +++ b/pkg/nex/command_context.go @@ -1,7 +1,9 @@ package nex import ( + "fmt" "github.com/goccy/go-json" + "net/http" "sync" ) @@ -15,6 +17,28 @@ type CommandCtx struct { values sync.Map } +func CtxValueMustBe[T any](c *CommandCtx, key string) (T, error) { + if val, ok := c.values.Load(key); ok { + if v, ok := val.(T); ok { + return v, nil + } + } + var out T + if err := c.Write([]byte(fmt.Sprintf("value %s not found in type %T", key, out)), "text/plain+error", http.StatusBadRequest); err != nil { + return out, err + } + return out, fmt.Errorf("value %s not found", key) +} + +func CtxValueShouldBe[T any](c *CommandCtx, key string, defaultValue T) T { + if val, ok := c.values.Load(key); ok { + if v, ok := val.(T); ok { + return v + } + } + return defaultValue +} + func (c *CommandCtx) Values() map[string]any { duplicate := make(map[string]any) c.values.Range(func(key, value any) bool { diff --git a/pkg/nex/const.go b/pkg/nex/const.go deleted file mode 100644 index 96f255d..0000000 --- a/pkg/nex/const.go +++ /dev/null @@ -1,8 +0,0 @@ -package nex - -const ( - ServiceTypeAuthProvider = "auth" - ServiceTypeFileProvider = "files" - ServiceTypeInteractiveProvider = "interactive" - ServiceTypeMessagingProvider = "messaging" -) diff --git a/pkg/nex/cruda/command.go b/pkg/nex/cruda/command.go index 8bfd04f..460a61b 100644 --- a/pkg/nex/cruda/command.go +++ b/pkg/nex/cruda/command.go @@ -1,7 +1,10 @@ package cruda import ( + "errors" "git.solsynth.dev/hypernet/nexus/pkg/nex" + "github.com/go-playground/validator/v10" + "gorm.io/gorm" "net/http" ) @@ -9,7 +12,7 @@ type CrudAction func(v *CrudConn) nex.CommandHandler func AddModel[T any](v *CrudConn, model T, id, prefix string, tags []string) error { funcList := []CrudAction{cmdList[T], cmdGet[T], cmdCreate[T], cmdUpdate[T], cmdDelete[T]} - funcCmds := []string{".list", ".get", ".create", ".update", ".delete"} + funcCmds := []string{".list", "", "", "", ""} funcMethods := []string{"get", "get", "put", "patch", "delete"} for idx, fn := range funcList { if err := v.Conn.AddCommand(prefix+id+funcCmds[idx], funcMethods[idx], tags, fn(v)); err != nil { @@ -19,10 +22,12 @@ func AddModel[T any](v *CrudConn, model T, id, prefix string, tags []string) err return nil } +var validate = validator.New(validator.WithRequiredStructEnabled()) + func cmdList[T any](c *CrudConn) nex.CommandHandler { return func(ctx *nex.CommandCtx) error { - take := int(ctx.ValueOrElse("query.take", 10).(int64)) - skip := int(ctx.ValueOrElse("query.skip", 0).(int64)) + take := int(nex.CtxValueShouldBe[int64](ctx, "query.take", 10)) + skip := int(nex.CtxValueShouldBe[int64](ctx, "query.skip", 0)) var str T var count int64 @@ -44,10 +49,16 @@ func cmdList[T any](c *CrudConn) nex.CommandHandler { func cmdGet[T any](c *CrudConn) nex.CommandHandler { return func(ctx *nex.CommandCtx) error { - id := ctx.ValueOrElse("query.id", 0).(int64) + id, err := nex.CtxValueMustBe[int64](ctx, "query.id") + if err != nil { + return err + } var out T if err := c.db.First(&out, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ctx.Write([]byte(err.Error()), "text/plain", http.StatusNotFound) + } return err } @@ -57,29 +68,34 @@ func cmdGet[T any](c *CrudConn) nex.CommandHandler { func cmdCreate[T any](c *CrudConn) nex.CommandHandler { return func(ctx *nex.CommandCtx) error { - var out T - if err := ctx.ReadJSON(&out); err != nil { + var payload T + if err := ctx.ReadJSON(&payload); err != nil { return err + } else if err := validate.Struct(payload); err != nil { + return ctx.Write([]byte(err.Error()), "text/plain+error", http.StatusBadRequest) } - // TODO validation - if err := c.db.Create(&out).Error; err != nil { + if err := c.db.Create(&payload).Error; err != nil { return err } - return ctx.JSON(out, http.StatusOK) + return ctx.JSON(payload, http.StatusOK) } } func cmdUpdate[T any](c *CrudConn) nex.CommandHandler { return func(ctx *nex.CommandCtx) error { - id := ctx.ValueOrElse("query.id", 0).(int64) + id, err := nex.CtxValueMustBe[int64](ctx, "query.id") + if err != nil { + return err + } var payload T if err := ctx.ReadJSON(&payload); err != nil { return err + } else if err := validate.Struct(payload); err != nil { + return ctx.Write([]byte(err.Error()), "text/plain+error", http.StatusBadRequest) } - // TODO validation var out T if err := c.db.Model(out).Where("id = ?", id).Updates(&payload).Error; err != nil { @@ -96,13 +112,19 @@ func cmdUpdate[T any](c *CrudConn) nex.CommandHandler { func cmdDelete[T any](c *CrudConn) nex.CommandHandler { return func(ctx *nex.CommandCtx) error { - id := ctx.ValueOrElse("query.id", 0).(int64) - - var out T - if err := c.db.Delete(&out, "id = ?", id).Error; err != nil { + id, err := nex.CtxValueMustBe[int64](ctx, "query.id") + if err != nil { return err } - return ctx.JSON(out, http.StatusOK) + var out T + if err := c.db.Delete(&out, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ctx.Write([]byte(err.Error()), "text/plain", http.StatusNotFound) + } + return err + } + + return ctx.Write(nil, "text/plain", http.StatusOK) } } diff --git a/pkg/nex/cruda/command_test.go b/pkg/nex/cruda/command_test.go index 6e88304..74b7050 100644 --- a/pkg/nex/cruda/command_test.go +++ b/pkg/nex/cruda/command_test.go @@ -11,7 +11,7 @@ import ( type Test struct { cruda.BaseModel - Content string `json:"content"` + Content string `json:"content" validate:"required"` } func TestCrudaCommand(t *testing.T) {