diff --git a/pkg/server/channels_api.go b/pkg/server/channels_api.go index 44bcaec..e909ea2 100644 --- a/pkg/server/channels_api.go +++ b/pkg/server/channels_api.go @@ -43,7 +43,13 @@ func listChannel(c *fiber.Ctx) error { func listOwnedChannel(c *fiber.Ctx) error { user := c.Locals("principal").(models.Account) - channels, err := services.ListChannelWithUser(user) + var err error + var channels []models.Channel + if val, ok := c.Locals("realm").(models.Realm); ok { + channels, err = services.ListChannelWithUser(user, val.ID) + } else { + channels, err = services.ListChannelWithUser(user) + } if err != nil { return fiber.NewError(fiber.StatusBadRequest, err.Error()) } @@ -54,7 +60,13 @@ func listOwnedChannel(c *fiber.Ctx) error { func listAvailableChannel(c *fiber.Ctx) error { user := c.Locals("principal").(models.Account) - channels, err := services.ListChannelIsAvailable(user) + var err error + var channels []models.Channel + if val, ok := c.Locals("realm").(models.Realm); ok { + channels, err = services.ListAvailableChannel(user, val.ID) + } else { + channels, err = services.ListAvailableChannel(user) + } if err != nil { return fiber.NewError(fiber.StatusBadRequest, err.Error()) } diff --git a/pkg/server/messages_api.go b/pkg/server/messages_api.go index ebdf2e3..faa8763 100644 --- a/pkg/server/messages_api.go +++ b/pkg/server/messages_api.go @@ -55,9 +55,19 @@ func newMessage(c *fiber.Ctx) error { return fmt.Errorf("you must write or upload some content in a single message") } - channel, member, err := services.GetAvailableChannelWithAlias(alias, user) - if err != nil { - return fiber.NewError(fiber.StatusNotFound, err.Error()) + var err error + var channel models.Channel + var member models.ChannelMember + if val, ok := c.Locals("realm").(models.Realm); ok { + channel, member, err = services.GetAvailableChannelWithAlias(alias, user, val.ID) + if err != nil { + return fiber.NewError(fiber.StatusNotFound, err.Error()) + } + } else { + channel, member, err = services.GetAvailableChannelWithAlias(alias, user) + if err != nil { + return fiber.NewError(fiber.StatusNotFound, err.Error()) + } } message := models.Message{ @@ -102,17 +112,30 @@ func editMessage(c *fiber.Ctx) error { return err } + var err error + var channel models.Channel + var member models.ChannelMember + if val, ok := c.Locals("realm").(models.Realm); ok { + channel, member, err = services.GetAvailableChannelWithAlias(alias, user, val.ID) + if err != nil { + return fiber.NewError(fiber.StatusNotFound, err.Error()) + } + } else { + channel, member, err = services.GetAvailableChannelWithAlias(alias, user) + if err != nil { + return fiber.NewError(fiber.StatusNotFound, err.Error()) + } + } + var message models.Message - if channel, member, err := services.GetAvailableChannelWithAlias(alias, user); err != nil { - return fiber.NewError(fiber.StatusNotFound, err.Error()) - } else if message, err = services.GetMessageWithPrincipal(channel, member, uint(messageId)); err != nil { + if message, err = services.GetMessageWithPrincipal(channel, member, uint(messageId)); err != nil { return fiber.NewError(fiber.StatusNotFound, err.Error()) } message.Content = data.Content message.Attachments = data.Attachments - message, err := services.EditMessage(message) + message, err = services.EditMessage(message) if err != nil { return fiber.NewError(fiber.StatusBadRequest, err.Error()) } @@ -125,14 +148,27 @@ func deleteMessage(c *fiber.Ctx) error { alias := c.Params("channel") messageId, _ := c.ParamsInt("messageId", 0) + var err error + var channel models.Channel + var member models.ChannelMember + if val, ok := c.Locals("realm").(models.Realm); ok { + channel, member, err = services.GetAvailableChannelWithAlias(alias, user, val.ID) + if err != nil { + return fiber.NewError(fiber.StatusNotFound, err.Error()) + } + } else { + channel, member, err = services.GetAvailableChannelWithAlias(alias, user) + if err != nil { + return fiber.NewError(fiber.StatusNotFound, err.Error()) + } + } + var message models.Message - if channel, member, err := services.GetAvailableChannelWithAlias(alias, user); err != nil { - return fiber.NewError(fiber.StatusNotFound, err.Error()) - } else if message, err = services.GetMessageWithPrincipal(channel, member, uint(messageId)); err != nil { + if message, err = services.GetMessageWithPrincipal(channel, member, uint(messageId)); err != nil { return fiber.NewError(fiber.StatusNotFound, err.Error()) } - message, err := services.DeleteMessage(message) + message, err = services.DeleteMessage(message) if err != nil { return fiber.NewError(fiber.StatusBadRequest, err.Error()) } diff --git a/pkg/services/channels.go b/pkg/services/channels.go index 1b5bae4..a68c21a 100644 --- a/pkg/services/channels.go +++ b/pkg/services/channels.go @@ -42,11 +42,11 @@ func GetChannelWithAlias(alias string, realmId ...uint) (models.Channel, error) return channel, nil } -func GetAvailableChannelWithAlias(alias string, user models.Account) (models.Channel, models.ChannelMember, error) { +func GetAvailableChannelWithAlias(alias string, user models.Account, realmId ...uint) (models.Channel, models.ChannelMember, error) { var err error var member models.ChannelMember var channel models.Channel - if channel, err = GetChannelWithAlias(alias); err != nil { + if channel, err = GetChannelWithAlias(alias, realmId...); err != nil { return channel, member, err } @@ -93,16 +93,22 @@ func ListChannel(realmId ...uint) ([]models.Channel, error) { return channels, nil } -func ListChannelWithUser(user models.Account) ([]models.Channel, error) { +func ListChannelWithUser(user models.Account, realmId ...uint) ([]models.Channel, error) { var channels []models.Channel - if err := database.C.Where(&models.Channel{AccountID: user.ID}).Find(&channels).Error; err != nil { + tx := database.C.Where(&models.Channel{AccountID: user.ID}) + if len(realmId) > 0 { + tx = tx.Where("realm_id = ?", realmId) + } else { + tx = tx.Where("realm_id IS NULL") + } + if err := tx.Find(&channels).Error; err != nil { return channels, err } return channels, nil } -func ListChannelIsAvailable(user models.Account) ([]models.Channel, error) { +func ListAvailableChannel(user models.Account, realmId ...uint) ([]models.Channel, error) { var channels []models.Channel var members []models.ChannelMember if err := database.C.Where(&models.ChannelMember{ @@ -115,7 +121,13 @@ func ListChannelIsAvailable(user models.Account) ([]models.Channel, error) { return item.ChannelID }) - if err := database.C.Where("id IN ?", idx).Find(&channels).Error; err != nil { + tx := database.C.Where("id IN ?", idx) + if len(realmId) > 0 { + tx = tx.Where("realm_id = ?", realmId) + } else { + tx = tx.Where("realm_id IS NULL") + } + if err := tx.Find(&channels).Error; err != nil { return channels, err }