From 23450c06902dbc372e94d282562d8e03f502410f Mon Sep 17 00:00:00 2001 From: LittleSheep Date: Sat, 1 Jun 2024 10:43:21 +0800 Subject: [PATCH] :bug: Bug fixes --- pkg/server/channels_api.go | 12 ++++++------ pkg/server/startup.go | 2 +- pkg/services/channels.go | 19 +++++++++---------- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/pkg/server/channels_api.go b/pkg/server/channels_api.go index 0ce29ef..a51660a 100644 --- a/pkg/server/channels_api.go +++ b/pkg/server/channels_api.go @@ -26,22 +26,22 @@ func getChannel(c *fiber.Ctx) error { return c.JSON(channel) } -func getChannelAvailability(c *fiber.Ctx) error { +func getChannelIdentity(c *fiber.Ctx) error { user := c.Locals("principal").(models.Account) alias := c.Params("channel") var err error - var channel models.Channel + var member models.ChannelMember if val, ok := c.Locals("realm").(models.Realm); ok { - channel, _, err = services.GetAvailableChannelWithAlias(alias, user, val.ID) + _, member, err = services.GetAvailableChannelWithAlias(alias, user, val.ID) } else { - channel, _, err = services.GetAvailableChannelWithAlias(alias, user) + _, member, err = services.GetAvailableChannelWithAlias(alias, user) } if err != nil { - return c.Status(fiber.StatusForbidden).JSON(channel) + return c.SendStatus(fiber.StatusForbidden) } - return c.JSON(channel) + return c.JSON(member) } func listChannel(c *fiber.Ctx) error { diff --git a/pkg/server/startup.go b/pkg/server/startup.go index 940fd76..398350a 100644 --- a/pkg/server/startup.go +++ b/pkg/server/startup.go @@ -72,7 +72,7 @@ func NewServer() { channels.Get("/me", authMiddleware, listOwnedChannel) channels.Get("/me/available", authMiddleware, listAvailableChannel) channels.Get("/:channel", getChannel) - channels.Get("/:channel/availability", authMiddleware, getChannelAvailability) + channels.Get("/:channel/me", authMiddleware, getChannelIdentity) channels.Post("/", authMiddleware, createChannel) channels.Post("/dm", authMiddleware, createDirectChannel) diff --git a/pkg/services/channels.go b/pkg/services/channels.go index e682be8..ac355e0 100644 --- a/pkg/services/channels.go +++ b/pkg/services/channels.go @@ -20,9 +20,11 @@ func GetChannelAliasAvailability(alias string) error { func GetChannel(id uint) (models.Channel, error) { var channel models.Channel - if err := database.C.Where(models.Channel{ + tx := database.C.Where(models.Channel{ BaseModel: models.BaseModel{ID: id}, - }).Preload("Account").First(&channel).Error; err != nil { + }).Preload("Account").Preload("Realm") + tx = PreloadDirectChannelMembers(tx) + if err := tx.First(&channel).Error; err != nil { return channel, err } @@ -31,12 +33,13 @@ func GetChannel(id uint) (models.Channel, error) { func GetChannelWithAlias(alias string, realmId ...uint) (models.Channel, error) { var channel models.Channel - tx := database.C.Where(models.Channel{Alias: alias}).Preload("Account") + tx := database.C.Where(models.Channel{Alias: alias}).Preload("Account").Preload("Realm") if len(realmId) > 0 { tx = tx.Where("realm_id = ?", realmId) } else { tx = tx.Where("realm_id IS NULL") } + tx = PreloadDirectChannelMembers(tx) if err := tx.First(&channel).Error; err != nil { return channel, err } @@ -69,11 +72,11 @@ func GetAvailableChannel(id uint, user models.Account) (models.Channel, models.C if channel, err = GetChannel(id); err != nil { return channel, member, err } - - if err := database.C.Where(models.ChannelMember{ + tx := database.C.Where(models.ChannelMember{ AccountID: user.ID, ChannelID: channel.ID, - }).First(&member).Error; err != nil { + }) + if err := tx.First(&member).Error; err != nil { return channel, member, fmt.Errorf("channel principal not found: %v", err.Error()) } @@ -113,8 +116,6 @@ func ListChannelWithUser(user models.Account, realmId ...uint) ([]models.Channel tx := database.C.Where(&models.Channel{AccountID: user.ID}).Preload("Realm") if len(realmId) > 0 { tx = tx.Where("realm_id = ?", realmId) - } else { - tx = tx.Where("realm_id IS NULL") } tx = PreloadDirectChannelMembers(tx) @@ -142,8 +143,6 @@ func ListAvailableChannel(user models.Account, realmId ...uint) ([]models.Channe tx := database.C.Preload("Realm").Where("id IN ?", idx) if len(realmId) > 0 { tx = tx.Where("realm_id = ?", realmId) - } else { - tx = tx.Where("realm_id IS NULL") } tx = PreloadDirectChannelMembers(tx)