diff --git a/pkg/models/messages.go b/pkg/models/messages.go index 05da36a..2535e14 100644 --- a/pkg/models/messages.go +++ b/pkg/models/messages.go @@ -7,7 +7,6 @@ type MessageType = uint8 const ( MessageTypeText = MessageType(iota) MessageTypeAudio - MessageTypeFile ) type Message struct { @@ -19,6 +18,8 @@ type Message struct { Attachments []Attachment `json:"attachments"` Channel Channel `json:"channel"` Sender ChannelMember `json:"sender"` + ReplyID *uint `json:"reply_id"` + ReplyTo *Message `json:"reply_to" gorm:"foreignKey:ReplyID"` ChannelID uint `json:"channel_id"` SenderID uint `json:"sender_id"` } diff --git a/pkg/server/messages_api.go b/pkg/server/messages_api.go index 6180cf9..9ccc8e1 100644 --- a/pkg/server/messages_api.go +++ b/pkg/server/messages_api.go @@ -1,6 +1,8 @@ package server import ( + "fmt" + "git.solsynth.dev/hydrogen/messaging/pkg/database" "git.solsynth.dev/hydrogen/messaging/pkg/models" "git.solsynth.dev/hydrogen/messaging/pkg/services" "github.com/gofiber/fiber/v2" @@ -35,16 +37,40 @@ func newTextMessage(c *fiber.Ctx) error { var data struct { Content string `json:"content" validate:"required"` Attachments []models.Attachment `json:"attachments"` + ReplyTo *uint `json:"reply_to"` } if err := BindAndValidate(c, &data); err != nil { return err } - var message models.Message - if channel, member, err := services.GetAvailableChannelWithAlias(alias, user); err != nil { + channel, member, err := services.GetAvailableChannelWithAlias(alias, user) + if err != nil { return fiber.NewError(fiber.StatusNotFound, err.Error()) - } else if message, err = services.NewTextMessage(data.Content, member, channel, data.Attachments...); err != nil { + } + + message := models.Message{ + Content: data.Content, + Metadata: nil, + Sender: member, + Channel: channel, + ChannelID: channel.ID, + SenderID: member.ID, + Attachments: data.Attachments, + Type: models.MessageTypeText, + } + + var replyTo models.Message + if data.ReplyTo != nil { + if err := database.C.Where("id = ?", data.ReplyTo).First(&replyTo).Error; err != nil { + return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("message to reply was not found: %v", err)) + } else { + message.ReplyTo = &replyTo + message.ReplyID = &replyTo.ID + } + } + + if message, err = services.NewMessage(message); err != nil { return fiber.NewError(fiber.StatusBadRequest, err.Error()) } diff --git a/pkg/services/messages.go b/pkg/services/messages.go index d81134f..80c21b8 100644 --- a/pkg/services/messages.go +++ b/pkg/services/messages.go @@ -28,6 +28,7 @@ func ListMessage(channel models.Channel, take int, offset int) ([]models.Message }).Limit(take).Offset(offset). Order("created_at DESC"). Preload("Attachments"). + Preload("ReplyTo"). Preload("Sender"). Preload("Sender.Account"). Find(&messages).Error; err != nil { @@ -44,6 +45,7 @@ func GetMessage(channel models.Channel, id uint) (models.Message, error) { BaseModel: models.BaseModel{ID: id}, ChannelID: channel.ID, }). + Preload("ReplyTo"). Preload("Attachments"). Preload("Sender"). Preload("Sender.Account"). @@ -67,24 +69,15 @@ func GetMessageWithPrincipal(channel models.Channel, member models.ChannelMember } } -func NewTextMessage(content string, sender models.ChannelMember, channel models.Channel, attachments ...models.Attachment) (models.Message, error) { - message := models.Message{ - Content: content, - Metadata: nil, - ChannelID: channel.ID, - SenderID: sender.ID, - Attachments: attachments, - Type: models.MessageTypeText, - } - +func NewMessage(message models.Message) (models.Message, error) { var members []models.ChannelMember if err := database.C.Save(&message).Error; err != nil { return message, err } else if err = database.C.Where(models.ChannelMember{ - ChannelID: channel.ID, + ChannelID: message.ChannelID, }).Find(&members).Error; err == nil { for _, member := range members { - message, _ = GetMessage(channel, message.ID) + message, _ = GetMessage(message.Channel, message.ID) PushCommand(member.AccountID, models.UnifiedCommand{ Action: "messages.new", Payload: message,