diff --git a/.idea/workspace.xml b/.idea/workspace.xml
index 9186299..d783639 100644
--- a/.idea/workspace.xml
+++ b/.idea/workspace.xml
@@ -4,10 +4,10 @@
-
+
-
-
+
+
@@ -151,7 +151,8 @@
-
+
+
true
diff --git a/pkg/models/notifications.go b/pkg/models/notifications.go
index 5daf4dd..c1302bf 100644
--- a/pkg/models/notifications.go
+++ b/pkg/models/notifications.go
@@ -31,8 +31,9 @@ const (
type NotificationSubscriber struct {
BaseModel
- UserAgent string `json:"user_agent"`
- Provider string `json:"provider"`
- DeviceID string `json:"device_id" gorm:"uniqueIndex"`
- AccountID uint `json:"account_id"`
+ UserAgent string `json:"user_agent"`
+ Provider string `json:"provider"`
+ DeviceID string `json:"device_id" gorm:"uniqueIndex"`
+ DeviceToken string `json:"device_token"`
+ AccountID uint `json:"account_id"`
}
diff --git a/pkg/server/notifications_api.go b/pkg/server/notifications_api.go
index 4ad4f5c..fbaa3ac 100644
--- a/pkg/server/notifications_api.go
+++ b/pkg/server/notifications_api.go
@@ -89,8 +89,9 @@ func addNotifySubscriber(c *fiber.Ctx) error {
user := c.Locals("principal").(models.Account)
var data struct {
- Provider string `json:"provider" validate:"required"`
- DeviceID string `json:"device_id" validate:"required"`
+ Provider string `json:"provider" validate:"required"`
+ DeviceToken string `json:"device_token" validate:"required"`
+ DeviceID string `json:"device_id" validate:"required"`
}
if err := utils.BindAndValidate(c, &data); err != nil {
@@ -99,8 +100,9 @@ func addNotifySubscriber(c *fiber.Ctx) error {
var count int64
if err := database.C.Where(&models.NotificationSubscriber{
- DeviceID: data.DeviceID,
- AccountID: user.ID,
+ DeviceID: data.DeviceID,
+ DeviceToken: data.DeviceToken,
+ AccountID: user.ID,
}).Model(&models.NotificationSubscriber{}).Count(&count).Error; err != nil || count > 0 {
return c.SendStatus(fiber.StatusOK)
}
@@ -109,6 +111,7 @@ func addNotifySubscriber(c *fiber.Ctx) error {
user,
data.Provider,
data.DeviceID,
+ data.DeviceToken,
c.Get(fiber.HeaderUserAgent),
)
if err != nil {
diff --git a/pkg/services/notifications.go b/pkg/services/notifications.go
index ceb2a1a..8f9de11 100644
--- a/pkg/services/notifications.go
+++ b/pkg/services/notifications.go
@@ -3,6 +3,7 @@ package services
import (
"context"
jsoniter "github.com/json-iterator/go"
+ "reflect"
"firebase.google.com/go/messaging"
"git.solsynth.dev/hydrogen/passport/pkg/database"
@@ -11,15 +12,32 @@ import (
"github.com/rs/zerolog/log"
)
-func AddNotifySubscriber(user models.Account, provider, device, ua string) (models.NotificationSubscriber, error) {
- subscriber := models.NotificationSubscriber{
- UserAgent: ua,
- Provider: provider,
- DeviceID: device,
+func AddNotifySubscriber(user models.Account, provider, id, tk, ua string) (models.NotificationSubscriber, error) {
+ var prev models.NotificationSubscriber
+ var subscriber models.NotificationSubscriber
+ if err := database.C.Where(&models.NotificationSubscriber{
+ DeviceID: id,
AccountID: user.ID,
+ }); err != nil {
+ subscriber = models.NotificationSubscriber{
+ UserAgent: ua,
+ Provider: provider,
+ DeviceID: id,
+ DeviceToken: tk,
+ AccountID: user.ID,
+ }
+ } else {
+ prev = subscriber
}
- err := database.C.Save(&subscriber).Error
+ subscriber.UserAgent = ua
+ subscriber.Provider = provider
+ subscriber.DeviceToken = tk
+
+ var err error
+ if !reflect.DeepEqual(subscriber, prev) {
+ err = database.C.Save(&subscriber).Error
+ }
return subscriber, err
}
@@ -72,7 +90,7 @@ func PushNotification(notification models.Notification) error {
Title: notification.Subject,
Body: notification.Content,
},
- Token: subscriber.DeviceID,
+ Token: subscriber.DeviceToken,
}
if response, err := client.Send(ctx, message); err != nil {