diff --git a/pkg/internal/server/api/channels_api.go b/pkg/internal/server/api/channels_api.go index 6891fec..39450f9 100644 --- a/pkg/internal/server/api/channels_api.go +++ b/pkg/internal/server/api/channels_api.go @@ -40,11 +40,13 @@ func getChannelIdentity(c *fiber.Ctx) error { var err error var member models.ChannelMember + if val, ok := c.Locals("realm").(models.Realm); ok { - _, member, err = services.GetAvailableChannelWithAlias(alias, user, val.ID) + _, member, err = services.GetChannelIdentity(alias, user.ID, val) } else { - _, member, err = services.GetAvailableChannelWithAlias(alias, user) + _, member, err = services.GetChannelIdentity(alias, user.ID) } + if err != nil { return c.SendStatus(fiber.StatusForbidden) } diff --git a/pkg/internal/server/api/events_api.go b/pkg/internal/server/api/events_api.go index 1d106ab..bf360df 100644 --- a/pkg/internal/server/api/events_api.go +++ b/pkg/internal/server/api/events_api.go @@ -95,20 +95,17 @@ func newRawEvent(c *fiber.Ctx) error { var err error var channel models.Channel var member models.ChannelMember + if val, ok := c.Locals("realm").(models.Realm); ok { - channel, member, err = services.GetAvailableChannelWithAlias(alias, user, val.ID) - if err != nil { - return fiber.NewError(fiber.StatusNotFound, err.Error()) - } else if member.PowerLevel < 0 { - return fiber.NewError(fiber.StatusForbidden, "you have not enough permission to send message") - } + channel, member, err = services.GetChannelIdentity(alias, user.ID, val) } else { - channel, member, err = services.GetAvailableChannelWithAlias(alias, user) - if err != nil { - return fiber.NewError(fiber.StatusNotFound, err.Error()) - } else if member.PowerLevel < 0 { - return fiber.NewError(fiber.StatusForbidden, "you have not enough permission to send message") - } + channel, member, err = services.GetChannelIdentity(alias, user.ID) + } + + if err != nil { + return fiber.NewError(fiber.StatusNotFound, err.Error()) + } else if member.PowerLevel < 0 { + return fiber.NewError(fiber.StatusForbidden, "you have not enough permission to send message") } event := models.Event{ diff --git a/pkg/internal/server/api/events_message_api.go b/pkg/internal/server/api/events_message_api.go index d632397..4b858da 100644 --- a/pkg/internal/server/api/events_message_api.go +++ b/pkg/internal/server/api/events_message_api.go @@ -38,11 +38,13 @@ func newMessageEvent(c *fiber.Ctx) error { var err error var channel models.Channel var member models.ChannelMember + if val, ok := c.Locals("realm").(models.Realm); ok { - channel, member, err = services.GetAvailableChannelWithAlias(alias, user, val.ID) + channel, member, err = services.GetChannelIdentity(alias, user.ID, val) } else { - channel, member, err = services.GetAvailableChannelWithAlias(alias, user) + channel, member, err = services.GetChannelIdentity(alias, user.ID) } + if err != nil { return fiber.NewError(fiber.StatusNotFound, err.Error()) } else if member.PowerLevel < 0 { @@ -95,11 +97,13 @@ func editMessageEvent(c *fiber.Ctx) error { var err error var channel models.Channel var member models.ChannelMember + if val, ok := c.Locals("realm").(models.Realm); ok { - channel, member, err = services.GetAvailableChannelWithAlias(alias, user, val.ID) + channel, member, err = services.GetChannelIdentity(alias, user.ID, val) } else { - channel, member, err = services.GetAvailableChannelWithAlias(alias, user) + channel, member, err = services.GetChannelIdentity(alias, user.ID) } + if err != nil { return fiber.NewError(fiber.StatusNotFound, err.Error()) } @@ -128,16 +132,15 @@ func deleteMessageEvent(c *fiber.Ctx) error { var err error var channel models.Channel var member models.ChannelMember + if val, ok := c.Locals("realm").(models.Realm); ok { - channel, member, err = services.GetAvailableChannelWithAlias(alias, user, val.ID) - if err != nil { - return fiber.NewError(fiber.StatusNotFound, err.Error()) - } + channel, member, err = services.GetChannelIdentity(alias, user.ID, val) } else { - channel, member, err = services.GetAvailableChannelWithAlias(alias, user) - if err != nil { - return fiber.NewError(fiber.StatusNotFound, err.Error()) - } + channel, member, err = services.GetChannelIdentity(alias, user.ID) + } + + if err != nil { + return fiber.NewError(fiber.StatusNotFound, err.Error()) } var event models.Event diff --git a/pkg/internal/services/channel_members.go b/pkg/internal/services/channel_members.go index d2a9675..9069a28 100644 --- a/pkg/internal/services/channel_members.go +++ b/pkg/internal/services/channel_members.go @@ -1,7 +1,12 @@ package services import ( + "context" "fmt" + localCache "git.solsynth.dev/hydrogen/messaging/pkg/internal/cache" + "github.com/eko/gocache/lib/v4/cache" + "github.com/eko/gocache/lib/v4/marshaler" + "github.com/eko/gocache/lib/v4/store" "git.solsynth.dev/hydrogen/messaging/pkg/internal/database" "git.solsynth.dev/hydrogen/messaging/pkg/internal/models" @@ -53,13 +58,41 @@ func AddChannelMember(user models.Account, target models.Channel) error { } err := database.C.Save(&member).Error + + if err == nil { + cacheManager := cache.New[any](localCache.S) + marshal := marshaler.New(cacheManager) + contx := context.Background() + + _ = marshal.Invalidate( + contx, + store.WithInvalidateTags([]string{ + fmt.Sprintf("channel#%d", target.ID), + fmt.Sprintf("user#%d", user.ID), + }), + ) + } + return err } func EditChannelMember(membership models.ChannelMember) (models.ChannelMember, error) { if err := database.C.Save(&membership).Error; err != nil { return membership, err + } else { + cacheManager := cache.New[any](localCache.S) + marshal := marshaler.New(cacheManager) + contx := context.Background() + + _ = marshal.Invalidate( + contx, + store.WithInvalidateTags([]string{ + fmt.Sprintf("channel#%d", membership.ChannelID), + fmt.Sprintf("user#%d", membership.AccountID), + }), + ) } + return membership, nil } @@ -73,5 +106,23 @@ func RemoveChannelMember(user models.Account, target models.Channel) error { return err } - return database.C.Delete(&member).Error + if err := database.C.Delete(&member).Error; err == nil { + database.C.Where("sender_id = ?").Delete(&models.Event{}) + + cacheManager := cache.New[any](localCache.S) + marshal := marshaler.New(cacheManager) + contx := context.Background() + + _ = marshal.Invalidate( + contx, + store.WithInvalidateTags([]string{ + fmt.Sprintf("channel#%d", target.ID), + fmt.Sprintf("user#%d", user.ID), + }), + ) + + return nil + } else { + return err + } } diff --git a/pkg/internal/services/channels.go b/pkg/internal/services/channels.go index 0a1f79f..3be6430 100644 --- a/pkg/internal/services/channels.go +++ b/pkg/internal/services/channels.go @@ -1,7 +1,12 @@ package services import ( + "context" "fmt" + localCache "git.solsynth.dev/hydrogen/messaging/pkg/internal/cache" + "github.com/eko/gocache/lib/v4/cache" + "github.com/eko/gocache/lib/v4/marshaler" + "github.com/eko/gocache/lib/v4/store" "regexp" "git.solsynth.dev/hydrogen/dealer/pkg/hyper" @@ -12,6 +17,72 @@ import ( "gorm.io/gorm" ) +type channelIdentityCacheEntry struct { + Channel models.Channel + ChannelMember models.ChannelMember +} + +func GetChannelIdentityCacheKey(channel string, user uint, realm ...uint) string { + if len(realm) > 0 { + return fmt.Sprintf("channel-identity-%s#%d@%d", channel, user, realm) + } else { + return fmt.Sprintf("channel-identity-%s#%d", channel, user) + } +} + +func CacheChannelIdentityCache(channel models.Channel, member models.ChannelMember, user uint, realm ...uint) { + key := GetChannelIdentityCacheKey(channel.Alias, user, realm...) + + cacheManager := cache.New[any](localCache.S) + marshal := marshaler.New(cacheManager) + contx := context.Background() + + _ = marshal.Set( + contx, + key, + channelIdentityCacheEntry{channel, member}, + store.WithTags([]string{"channel-identity", fmt.Sprintf("channel#%d", channel.ID), fmt.Sprintf("user#%d", user)}), + ) +} + +func GetChannelIdentity(alias string, user uint, realm ...models.Realm) (models.Channel, models.ChannelMember, error) { + cacheManager := cache.New[any](localCache.S) + marshal := marshaler.New(cacheManager) + contx := context.Background() + + var err error + var channel models.Channel + var member models.ChannelMember + + hitCache := false + if len(realm) > 0 { + if val, err := marshal.Get(contx, GetChannelIdentityCacheKey(alias, user, realm[0].ID), new(channelIdentityCacheEntry)); err == nil { + entry := val.(*channelIdentityCacheEntry) + channel = entry.Channel + member = entry.ChannelMember + hitCache = true + } + } else { + if val, err := marshal.Get(contx, GetChannelIdentityCacheKey(alias, user), new(channelIdentityCacheEntry)); err == nil { + entry := val.(*channelIdentityCacheEntry) + channel = entry.Channel + member = entry.ChannelMember + hitCache = true + } + } + if !hitCache { + if len(realm) > 0 { + channel, member, err = GetAvailableChannelWithAlias(alias, user, realm[0].ID) + CacheChannelIdentityCache(channel, member, user, realm[0].ID) + } else { + channel, member, err = GetAvailableChannelWithAlias(alias, user) + CacheChannelIdentityCache(channel, member, user) + } + } + + return channel, member, err +} + func GetChannelAliasAvailability(alias string) error { if !regexp.MustCompile("^[a-z0-9-]+$").MatchString(alias) { return fmt.Errorf("channel alias should only contains lowercase letters, numbers, and hyphens") @@ -48,7 +119,7 @@ func GetChannelWithAlias(alias string, realmId ...uint) (models.Channel, error) return channel, nil } -func GetAvailableChannelWithAlias(alias string, user models.Account, realmId ...uint) (models.Channel, models.ChannelMember, error) { +func GetAvailableChannelWithAlias(alias string, user uint, realmId ...uint) (models.Channel, models.ChannelMember, error) { var err error var member models.ChannelMember var channel models.Channel @@ -57,7 +128,7 @@ func GetAvailableChannelWithAlias(alias string, user models.Account, realmId ... } if err := database.C.Where(models.ChannelMember{ - AccountID: user.ID, + AccountID: user, ChannelID: channel.ID, }).First(&member).Error; err != nil { return channel, member, fmt.Errorf("channel principal not found: %v", err.Error()) @@ -183,9 +254,35 @@ func EditChannel(channel models.Channel, alias, name, description string, isPubl err := database.C.Save(&channel).Error + if err == nil { + cacheManager := cache.New[any](localCache.S) + marshal := marshaler.New(cacheManager) + contx := context.Background() + + _ = marshal.Invalidate( + contx, + store.WithInvalidateTags([]string{fmt.Sprintf("channel#%d", channel.ID)}), + ) + } + return channel, err } func DeleteChannel(channel models.Channel) error { - return database.C.Delete(&channel).Error + if err := database.C.Delete(&channel).Error; err == nil { + database.C.Where("channel_id = ?", channel.ID).Delete(&models.Event{}) + + cacheManager := cache.New[any](localCache.S) + marshal := marshaler.New(cacheManager) + contx := context.Background() + + _ = marshal.Invalidate( + contx, + store.WithInvalidateTags([]string{fmt.Sprintf("channel#%d", channel.ID)}), + ) + + return nil + } else { + return err + } } diff --git a/pkg/internal/services/status.go b/pkg/internal/services/status.go index cae2070..f9b3b6d 100644 --- a/pkg/internal/services/status.go +++ b/pkg/internal/services/status.go @@ -34,7 +34,7 @@ func SetTypingStatus(channelId uint, userId uint) error { hitCache := false if val, err := marshal.Get(contx, GetTypingStatusQueryCacheKey(channelId, userId), new(statusQueryCacheEntry)); err == nil { - entry := val.(statusQueryCacheEntry) + entry := val.(*statusQueryCacheEntry) broadcastTarget = entry.Target data = entry.Data hitCache = true