diff --git a/pkg/authkit/models/punishments.go b/pkg/authkit/models/punishments.go new file mode 100644 index 0000000..a62842b --- /dev/null +++ b/pkg/authkit/models/punishments.go @@ -0,0 +1,26 @@ +package models + +import ( + "time" + + "gorm.io/datatypes" +) + +const ( + PunishmentTypeStrike = iota + PunishmentTypeLimited + PunishmentTypeDisabled +) + +type Punishment struct { + BaseModel + + Reason string `json:"reason"` + Type int `json:"type"` + PermNodes datatypes.JSONMap `json:"perm_nodes"` + ExpiredAt *time.Time `json:"expired_at"` + Account Account `json:"account"` + AccountID uint `json:"account_id"` + Moderator *Account `json:"moderator"` + ModeratorID *uint `json:"moderator_id"` +} diff --git a/pkg/internal/services/punishments.go b/pkg/internal/services/punishments.go new file mode 100644 index 0000000..bd2294f --- /dev/null +++ b/pkg/internal/services/punishments.go @@ -0,0 +1,150 @@ +package services + +import ( + "fmt" + "time" + + "git.solsynth.dev/hypernet/passport/pkg/authkit/models" + "git.solsynth.dev/hypernet/passport/pkg/internal/database" +) + +func NewPunishment(in models.Punishment, moderator ...models.Account) (models.Punishment, error) { + if len(moderator) > 0 { + in.Moderator = &moderator[0] + in.ModeratorID = &moderator[0].ID + } + + // If user got more than 2 strikes, it will upgrade to limited + if in.Type == models.PunishmentTypeStrike { + var count int64 + if err := database.C.Model(&models.Punishment{}). + Where("account_id = ? AND type = ?", in.AccountID, models.PunishmentTypeStrike). + Count(&count).Error; err != nil { + return in, err + } + if count > 2 { + in.Type = models.PunishmentTypeLimited + } + } + + if err := database.C.Create(&in).Error; err != nil { + return in, err + } + + return in, nil +} + +func EditPunishment(punishment models.Punishment) (models.Punishment, error) { + if err := database.C.Save(&punishment).Error; err != nil { + return punishment, err + } + return punishment, nil +} + +func DeletePunishment(punishment models.Punishment) error { + if err := database.C.Delete(&punishment).Error; err != nil { + return err + } + return nil +} + +func GetPunishment(id uint, preload ...bool) (models.Punishment, error) { + tx := database.C + if len(preload) > 0 && preload[0] { + tx = tx.Preload("Moderator").Preload("Account") + } + + var punishment models.Punishment + if err := tx.First(&punishment, id).Error; err != nil { + return punishment, err + } + return punishment, nil +} + +func GetMadePunishment(id uint, moderator models.Account) (models.Punishment, error) { + var punishment models.Punishment + if err := database.C.Where("id = ? AND moderator_id = ?", id, moderator.ID).First(&punishment).Error; err != nil { + return punishment, err + } + return punishment, nil +} + +func ListPunishments(user models.Account) ([]models.Punishment, error) { + var punishments []models.Punishment + if err := database.C. + Where("account_id = ? AND (expired_at IS NULL OR expired_at <= ?)", user.ID, time.Now()). + Preload("Moderator"). + Order("created_at DESC"). + Find(&punishments).Error; err != nil { + return nil, err + } + return punishments, nil +} + +func CountAllPunishments() (int64, error) { + var count int64 + if err := database.C. + Model(&models.Punishment{}). + Count(&count).Error; err != nil { + return 0, err + } + return count, nil +} + +func ListAllPunishments(take, offset int) ([]models.Punishment, error) { + var punishments []models.Punishment + if err := database.C. + Preload("Account"). + Preload("Moderator"). + Order("created_at DESC"). + Take(take).Offset(offset). + Find(&punishments).Error; err != nil { + return nil, err + } + return punishments, nil +} + +func CountMadePunishments(moderator models.Account) (int64, error) { + var count int64 + if err := database.C. + Model(&models.Punishment{}). + Where("moderator_id = ?", moderator.ID). + Count(&count).Error; err != nil { + return 0, err + } + return count, nil +} + +func ListMadePunishments(moderator models.Account, take, offset int) ([]models.Punishment, error) { + var punishments []models.Punishment + if err := database.C. + Where("moderator_id = ?", moderator.ID). + Preload("Account"). + Order("created_at DESC"). + Take(take).Offset(offset). + Find(&punishments).Error; err != nil { + return nil, err + } + return punishments, nil +} + +func CheckLoginAbility(user models.Account) error { + var punishments []models.Punishment + if err := database.C.Where("account_id = ? AND (expired_at IS NULL OR expired_at <= ?)", user.ID, time.Now()). + Find(&punishments).Error; err != nil { + return fmt.Errorf("failed to get punishments: %v", err) + } + + for _, punishment := range punishments { + if punishment.Type == models.PunishmentTypeDisabled { + return fmt.Errorf("account has been fully disabled due to: %s (case #%d)", punishment.Reason, punishment.ID) + } + // Limited punishment with no permissions override is fully limited + // Refer https://solsynth.dev/terms/basic-law#provision-and-discontinuation-of-services + if punishment.Type == models.PunishmentTypeLimited && len(punishment.PermNodes) == 0 { + return fmt.Errorf("account has been limited login due to: %s (case #%d)", punishment.Reason, punishment.ID) + } + } + + return nil +} diff --git a/pkg/internal/web/api/auth_api.go b/pkg/internal/web/api/auth_api.go index 0d2ec1e..2f5fe3d 100644 --- a/pkg/internal/web/api/auth_api.go +++ b/pkg/internal/web/api/auth_api.go @@ -39,6 +39,8 @@ func doAuthenticate(c *fiber.Ctx) error { return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("account was not found: %v", err.Error())) } else if user.SuspendedAt != nil { return fiber.NewError(fiber.StatusForbidden, "account was suspended") + } else if err := services.CheckLoginAbility(user); err != nil { + return err } ticket, err := services.NewTicket(user, c.IP(), c.Get(fiber.HeaderUserAgent)) diff --git a/pkg/internal/web/api/index.go b/pkg/internal/web/api/index.go index 113ce1a..786a40b 100644 --- a/pkg/internal/web/api/index.go +++ b/pkg/internal/web/api/index.go @@ -54,6 +54,16 @@ func MapControllers(app *fiber.App, baseURL string) { } } + punishments := api.Group("/punishments").Name("Punishments API") + { + punishments.Get("/", listUserPunishment) + punishments.Get("/given", listMadePunishment) + punishments.Get("/:id", getPunishment) + punishments.Post("/", createPunishment) + punishments.Put("/:id", editPunishment) + punishments.Delete("/:id", deletePunishment) + } + api.Get("/users", getUserInBatch) api.Get("/users/lookup", lookupAccount) api.Get("/users/search", searchAccount) diff --git a/pkg/internal/web/api/punishments_api.go b/pkg/internal/web/api/punishments_api.go new file mode 100644 index 0000000..5990b60 --- /dev/null +++ b/pkg/internal/web/api/punishments_api.go @@ -0,0 +1,180 @@ +package api + +import ( + "time" + + "git.solsynth.dev/hypernet/passport/pkg/authkit/models" + "git.solsynth.dev/hypernet/passport/pkg/internal/database" + "git.solsynth.dev/hypernet/passport/pkg/internal/services" + "git.solsynth.dev/hypernet/passport/pkg/internal/web/exts" + "github.com/gofiber/fiber/v2" +) + +func getPunishment(c *fiber.Ctx) error { + id, _ := c.ParamsInt("id") + data, err := services.GetPunishment(uint(id), true) + if err != nil { + return err + } + return c.JSON(data) +} + +func listUserPunishment(c *fiber.Ctx) error { + if err := exts.EnsureAuthenticated(c); err != nil { + return err + } + user := c.Locals("user").(models.Account) + + data, err := services.ListPunishments(user) + if err != nil { + return err + } + return c.JSON(data) +} + +func listMadePunishment(c *fiber.Ctx) error { + if err := exts.EnsureAuthenticated(c); err != nil { + return err + } + moderator := c.Locals("user").(models.Account) + + take := c.QueryInt("take", 0) + offset := c.QueryInt("offset", 0) + + if c.QueryBool("all", false) { + if err := exts.EnsureGrantedPerm(c, "OverridePunishments", true); err != nil { + return err + } + count, err := services.CountAllPunishments() + data, err := services.ListAllPunishments(take, offset) + if err != nil { + return err + } + return c.JSON(fiber.Map{ + "count": count, + "data": data, + }) + } + + count, err := services.CountMadePunishments(moderator) + data, err := services.ListMadePunishments(moderator, take, offset) + if err != nil { + return err + } + return c.JSON(fiber.Map{ + "count": count, + "data": data, + }) +} + +func createPunishment(c *fiber.Ctx) error { + if err := exts.EnsureGrantedPerm(c, "CreatePunishments", true); err != nil { + return err + } + user := c.Locals("user").(models.Account) + + var data struct { + Reason string `json:"reason" validate:"required"` + Type int `json:"type"` + ExpiredAt *time.Time `json:"expired_at"` + PermNodes map[string]any `json:"perm_nodes"` + AccountID uint `json:"account_id"` + } + + if err := exts.BindAndValidate(c, &data); err != nil { + return err + } + + var account models.Account + if err := database.C.Where("id = ?", data.AccountID).First(&account).Error; err != nil { + return fiber.NewError(fiber.StatusNotFound, err.Error()) + } + + punishment := models.Punishment{ + Reason: data.Reason, + Type: data.Type, + PermNodes: data.PermNodes, + ExpiredAt: data.ExpiredAt, + Account: account, + AccountID: account.ID, + } + + if punishment, err := services.NewPunishment(punishment, user); err != nil { + return err + } else { + return c.JSON(punishment) + } +} + +func editPunishment(c *fiber.Ctx) error { + if err := exts.EnsureAuthenticated(c); err != nil { + return err + } + user := c.Locals("user").(models.Account) + + id, _ := c.ParamsInt("id", 0) + + var data struct { + Reason string `json:"reason" validate:"required"` + Type int `json:"type"` + ExpiredAt *time.Time `json:"expired_at"` + PermNodes map[string]any `json:"perm_nodes"` + } + + if err := exts.BindAndValidate(c, &data); err != nil { + return err + } + + var err error + var punishment models.Punishment + if c.QueryBool("override", false) { + if err = exts.EnsureGrantedPerm(c, "OverridePunishments", true); err != nil { + return err + } + punishment, err = services.GetPunishment(uint(id)) + } else { + punishment, err = services.GetMadePunishment(uint(id), user) + } + if err != nil { + return fiber.NewError(fiber.StatusNotFound, err.Error()) + } + + punishment.Reason = data.Reason + punishment.Type = data.Type + punishment.ExpiredAt = data.ExpiredAt + punishment.PermNodes = data.PermNodes + + if punishment, err := services.EditPunishment(punishment); err != nil { + return err + } else { + return c.JSON(punishment) + } +} + +func deletePunishment(c *fiber.Ctx) error { + if err := exts.EnsureAuthenticated(c); err != nil { + return err + } + user := c.Locals("user").(models.Account) + + id := c.QueryInt("id") + + var err error + var punishment models.Punishment + if c.QueryBool("override", false) { + if err = exts.EnsureGrantedPerm(c, "OverridePunishments", true); err != nil { + return err + } + punishment, err = services.GetPunishment(uint(id)) + } else { + punishment, err = services.GetMadePunishment(uint(id), user) + } + if err != nil { + return fiber.NewError(fiber.StatusNotFound, err.Error()) + } + + if err := services.DeletePunishment(punishment); err != nil { + return err + } + return c.SendStatus(fiber.StatusOK) +}