🎉 Initial Commit of Ring
This commit is contained in:
130
pkg/ring/clients/email.go
Normal file
130
pkg/ring/clients/email.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package clients
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type SMTPSettings struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Encryption string // "tls", "ssl", "none"
|
||||
FromAddress string
|
||||
FromName string
|
||||
}
|
||||
|
||||
var smtpSettings SMTPSettings
|
||||
|
||||
func InitSMTPSettings() {
|
||||
smtpSettings = SMTPSettings{
|
||||
Host: viper.GetString("smtp.host"),
|
||||
Port: viper.GetInt("smtp.port"),
|
||||
Username: viper.GetString("smtp.username"),
|
||||
Password: viper.GetString("smtp.password"),
|
||||
Encryption: viper.GetString("smtp.encryption"),
|
||||
FromAddress: viper.GetString("smtp.from_address"),
|
||||
FromName: viper.GetString("smtp.from_name"),
|
||||
}
|
||||
|
||||
if smtpSettings.Host == "" || smtpSettings.Port == 0 || smtpSettings.FromAddress == "" {
|
||||
log.Warn().Msg("SMTP configuration incomplete. Email sending may not work.")
|
||||
} else {
|
||||
log.Info().Msgf("SMTP client initialized for %s:%d", smtpSettings.Host, smtpSettings.Port)
|
||||
}
|
||||
}
|
||||
|
||||
func GetSMTPSettings() SMTPSettings {
|
||||
return smtpSettings
|
||||
}
|
||||
|
||||
// SendEmail sends an email using the configured SMTP settings.
|
||||
func SendEmail(toAddress, subject, body string) error {
|
||||
if smtpSettings.Host == "" {
|
||||
return fmt.Errorf("SMTP client not initialized or host is empty")
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", smtpSettings.Host, smtpSettings.Port)
|
||||
|
||||
// Setup authentication
|
||||
auth := smtp.PlainAuth("", smtpSettings.Username, smtpSettings.Password, smtpSettings.Host)
|
||||
|
||||
// Construct the email message
|
||||
mime := "MIME-version: 1.0;\nContent-Type: text/html; charset=\"UTF-8\";\n"
|
||||
msg := []byte("From: " + smtpSettings.FromName + " <" + smtpSettings.FromAddress + ">\r\n" +
|
||||
"To: " + toAddress + "\r\n" +
|
||||
"Subject: " + subject + "\r\n" +
|
||||
mime + "\r\n" +
|
||||
body + "\r\n")
|
||||
|
||||
var err error
|
||||
switch smtpSettings.Encryption {
|
||||
case "tls":
|
||||
// TLS encryption
|
||||
err = sendMailTLS(addr, auth, smtpSettings.FromAddress, []string{toAddress}, msg)
|
||||
case "ssl":
|
||||
// SSL encryption (usually port 465)
|
||||
err = smtp.SendMail(addr, auth, smtpSettings.FromAddress, []string{toAddress}, msg)
|
||||
case "none":
|
||||
// No encryption (usually port 25 or 587 without STARTTLS)
|
||||
err = smtp.SendMail(addr, auth, smtpSettings.FromAddress, []string{toAddress}, msg)
|
||||
default:
|
||||
return fmt.Errorf("unsupported SMTP encryption type: %s", smtpSettings.Encryption)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
log.Info().Msgf("Email sent successfully to %s", toAddress)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendMailTLS sends an email over a TLS connection.
|
||||
func sendMailTLS(addr string, auth smtp.Auth, from string, to []string, msg []byte) error {
|
||||
conn, err := tls.Dial("tcp", addr, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client, err := smtp.NewClient(conn, smtpSettings.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
if auth != nil {
|
||||
if ok, _ := client.Extension("AUTH"); ok {
|
||||
if err = client.Auth(auth); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err = client.Mail(from); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, addr := range to {
|
||||
if err = client.Rcpt(addr); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
w, err := client.Data()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = w.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return client.Quit()
|
||||
}
|
||||
29
pkg/ring/clients/nats.go
Normal file
29
pkg/ring/clients/nats.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package clients
|
||||
|
||||
import (
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var natsConn *nats.Conn
|
||||
|
||||
// InitNATSClient initializes the NATS client connection.
|
||||
func InitNATSClient() {
|
||||
natsURL := viper.GetString("nats.url")
|
||||
if natsURL == "" {
|
||||
log.Fatal().Msg("NATS URL not configured in viper (nats.url)")
|
||||
}
|
||||
|
||||
var err error
|
||||
natsConn, err = nats.Connect(natsURL)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to connect to NATS server")
|
||||
}
|
||||
log.Info().Msgf("Connected to NATS server: %s", natsURL)
|
||||
}
|
||||
|
||||
// GetNATSClient returns the initialized NATS client connection.
|
||||
func GetNATSClient() *nats.Conn {
|
||||
return natsConn
|
||||
}
|
||||
67
pkg/ring/clients/push.go
Normal file
67
pkg/ring/clients/push.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package clients
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
firebase "firebase.google.com/go/v4"
|
||||
"firebase.google.com/go/v4/messaging"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sideshow/apns2"
|
||||
"github.com/sideshow/apns2/token"
|
||||
"github.com/spf13/viper"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
var (
|
||||
apnsClient *apns2.Client
|
||||
firebaseClient *messaging.Client
|
||||
)
|
||||
|
||||
func InitPushClients() {
|
||||
// Initialize APNs Client
|
||||
apnsCertPath := viper.GetString("apns.certificate_path")
|
||||
apnsKeyID := viper.GetString("apns.key_id")
|
||||
apnsTeamID := viper.GetString("apns.team_id")
|
||||
|
||||
if apnsCertPath != "" && apnsKeyID != "" && apnsTeamID != "" {
|
||||
authKey, err := token.AuthKeyFromFile(apnsCertPath)
|
||||
token := &token.Token{
|
||||
AuthKey: authKey,
|
||||
KeyID: apnsKeyID,
|
||||
TeamID: apnsTeamID,
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to create APNs auth key")
|
||||
}
|
||||
apnsClient = apns2.NewTokenClient(token)
|
||||
apnsClient.Production() // Use Production environment
|
||||
log.Info().Msg("APNs client initialized in production mode")
|
||||
} else {
|
||||
log.Warn().Msg("APNs configuration missing. Skipping APNs client initialization.")
|
||||
}
|
||||
|
||||
// Initialize Firebase Client
|
||||
firebaseServiceAccountPath := viper.GetString("firebase.service_account_path")
|
||||
if firebaseServiceAccountPath != "" {
|
||||
opt := option.WithCredentialsFile(firebaseServiceAccountPath)
|
||||
app, err := firebase.NewApp(context.Background(), nil, opt)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to create Firebase app")
|
||||
}
|
||||
firebaseClient, err = app.Messaging(context.Background())
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to create Firebase Messaging client")
|
||||
}
|
||||
log.Info().Msg("Firebase Messaging client initialized")
|
||||
} else {
|
||||
log.Warn().Msg("Firebase service account path missing. Skipping Firebase client initialization.")
|
||||
}
|
||||
}
|
||||
|
||||
func GetAPNsClient() *apns2.Client {
|
||||
return apnsClient
|
||||
}
|
||||
|
||||
func GetFirebaseClient() *messaging.Client {
|
||||
return firebaseClient
|
||||
}
|
||||
19
pkg/ring/infra/db.go
Normal file
19
pkg/ring/infra/db.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package infra
|
||||
|
||||
import (
|
||||
"github.com/spf13/viper"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var Db *gorm.DB
|
||||
|
||||
func ConnectDb() error {
|
||||
dsn := viper.GetString("database.dsn")
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
Db = db
|
||||
return nil
|
||||
}
|
||||
105
pkg/ring/main.go
105
pkg/ring/main.go
@@ -2,17 +2,31 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.solsynth.dev/goatworks/turbine/pkg/ring/clients" // Add this import
|
||||
"git.solsynth.dev/goatworks/turbine/pkg/ring/infra"
|
||||
"git.solsynth.dev/goatworks/turbine/pkg/ring/routes"
|
||||
"git.solsynth.dev/goatworks/turbine/pkg/ring/services"
|
||||
"git.solsynth.dev/goatworks/turbine/pkg/ring/websocket" // Add this import
|
||||
"git.solsynth.dev/goatworks/turbine/pkg/shared/hash"
|
||||
pb "git.solsynth.dev/goatworks/turbine/pkg/shared/proto/gen"
|
||||
"git.solsynth.dev/goatworks/turbine/pkg/shared/registrar"
|
||||
|
||||
fiber_websocket "github.com/gofiber/contrib/v3/websocket" // Add this alias import
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
"github.com/spf13/viper"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -31,12 +45,49 @@ func main() {
|
||||
log.Fatal().Err(err).Msg("Failed to read config file...")
|
||||
}
|
||||
log.Info().Msg("Configuration loaded.")
|
||||
|
||||
clients.InitPushClients() // Initialize APNs and Firebase clients
|
||||
clients.InitSMTPSettings() // Initialize SMTP client settings
|
||||
clients.InitNATSClient() // Initialize NATS client
|
||||
|
||||
// --- gRPC Server ---
|
||||
grpcListenAddr := viper.GetString("grpc.listen")
|
||||
if grpcListenAddr == "" {
|
||||
log.Fatal().Msg("grpc.listen not configured")
|
||||
}
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
wsManager := websocket.NewManager(clients.GetNATSClient()) // Pass NATS client to WebSocket Manager
|
||||
ringService := &services.RingServiceServerImpl{
|
||||
WsManager: wsManager, // Inject WebSocket Manager into RingService
|
||||
}
|
||||
pb.RegisterRingServiceServer(grpcServer, ringService)
|
||||
pb.RegisterRingHandlerServiceServer(grpcServer, ringService)
|
||||
|
||||
grpcLis, err := net.Listen("tcp", grpcListenAddr)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msgf("Failed to listen for gRPC: %s", grpcListenAddr)
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Info().Msgf("gRPC server listening on %s", grpcListenAddr)
|
||||
if err := grpcServer.Serve(grpcLis); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to serve gRPC")
|
||||
}
|
||||
}()
|
||||
|
||||
// --- Service Registration ---
|
||||
etcdEndpoints := viper.GetStringSlice("etcd.endpoints")
|
||||
if len(etcdEndpoints) == 0 {
|
||||
log.Fatal().Msg("etcd.endpoints not configured")
|
||||
}
|
||||
|
||||
if err := infra.ConnectDb(); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to connect database.")
|
||||
} else {
|
||||
infra.Db.AutoMigrate()
|
||||
}
|
||||
|
||||
if viper.GetBool("etcd.insecure") {
|
||||
for i, ep := range etcdEndpoints {
|
||||
if !strings.HasPrefix(ep, "http://") && !strings.HasPrefix(ep, "https://") {
|
||||
@@ -61,7 +112,7 @@ func main() {
|
||||
log.Fatal().Err(err).Msg("Invalid listen address")
|
||||
}
|
||||
|
||||
serviceName := "config"
|
||||
serviceName := "ring" // This should probably be "ring"
|
||||
instanceID := fmt.Sprint(hash.Hash(fmt.Sprintf("%s-%s-%d", serviceName, host, port)))[:8]
|
||||
|
||||
err = serviceReg.Register(serviceName, "http", instanceID, host, port, 30)
|
||||
@@ -75,23 +126,45 @@ func main() {
|
||||
ServerHeader: "Turbine Ring",
|
||||
})
|
||||
|
||||
// This is the main endpoint that serves the configuration as JSON.
|
||||
app.Get("/", func(c fiber.Ctx) error {
|
||||
log.Info().Msg("Serving shared configuration as JSON")
|
||||
api := app.Group("/api")
|
||||
{
|
||||
api.Post("/notifications/subscription", routes.CreatePushSubscription)
|
||||
}
|
||||
|
||||
// Use a new Viper instance to read the shared config
|
||||
v := viper.New()
|
||||
v.SetConfigName("shared_config")
|
||||
v.AddConfigPath(".") // Look in the current directory (pkg/config)
|
||||
v.SetConfigType("toml")
|
||||
// Initialize WebSocket Controller
|
||||
wsController := websocket.NewWebSocketController(wsManager)
|
||||
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to read shared_config.toml")
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
|
||||
"error": "could not load configuration",
|
||||
})
|
||||
// WebSocket endpoint
|
||||
app.Use("/ws", func(c fiber.Ctx) error {
|
||||
// Mock authentication based on C# example
|
||||
// In a real scenario, you'd extract user/session from JWT or similar
|
||||
// and set c.Locals("currentUser") and c.Locals("currentSession")
|
||||
c.Locals("currentUser", &pb.Account{Id: uuid.New().String(), Name: "mock_user", Nick: "Mock User"})
|
||||
c.Locals("currentSession", &pb.AuthSession{ClientId: lo.ToPtr(uuid.New().String())})
|
||||
|
||||
if fiber_websocket.IsWebSocketUpgrade(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
return c.JSON(v.AllSettings())
|
||||
return fiber.ErrUpgradeRequired
|
||||
})
|
||||
app.Get("/ws", fiber_websocket.New(wsController.HandleWebSocket))
|
||||
|
||||
// Graceful shutdown
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
<-c
|
||||
log.Info().Msg("Shutting down servers...")
|
||||
if err := app.ShutdownWithTimeout(5 * time.Second); err != nil {
|
||||
log.Error().Err(err).Msg("Fiber server shutdown error")
|
||||
}
|
||||
grpcServer.GracefulStop()
|
||||
log.Info().Msg("Servers gracefully stopped")
|
||||
}()
|
||||
|
||||
err = app.Listen(listenAddr, fiber.ListenConfig{DisableStartupMessage: true})
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to start the server...")
|
||||
}
|
||||
}
|
||||
|
||||
45
pkg/ring/models/notification.go
Normal file
45
pkg/ring/models/notification.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type PushProvider int
|
||||
|
||||
const (
|
||||
PushProviderApple PushProvider = iota
|
||||
PushProviderGoogle
|
||||
)
|
||||
|
||||
type Notification struct {
|
||||
Id uuid.UUID `gorm:"primarykey;type:uuid;default:gen_random_uuid()" json:"id"`
|
||||
Topic string `gorm:"not null;size:1024" json:"topic"`
|
||||
Title *string `json:"title"`
|
||||
Subtitle *string `json:"subtitle"`
|
||||
Content *string `json:"content"`
|
||||
Meta map[string]any `gorm:"type:jsonb;default:'{}'::jsonb;not null" json:"meta"`
|
||||
Priority int `gorm:"default:10" json:"priority"`
|
||||
ViewedAt *time.Time `json:"viewed_at"`
|
||||
AccountId uuid.UUID `gorm:"type:uuid;not null" json:"account_id"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"`
|
||||
}
|
||||
|
||||
type NotificationPushSubscription struct {
|
||||
Id uuid.UUID `gorm:"primarykey;type:uuid;default:gen_random_uuid()" json:"id"`
|
||||
AccountId uuid.UUID `gorm:"type:uuid;not null;uniqueIndex:account_device_deleted" json:"account_id"`
|
||||
DeviceId string `gorm:"not null;size:8192;uniqueIndex:account_device_deleted" json:"device_id"`
|
||||
DeviceToken string `gorm:"not null;size:8192" json:"device_token"`
|
||||
Provider PushProvider `gorm:"not null" json:"provider"`
|
||||
CountDelivered int `gorm:"default:0" json:"count_delivered"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"uniqueIndex:account_device_deleted,index" json:"deleted_at"`
|
||||
}
|
||||
BIN
pkg/ring/ring
Executable file
BIN
pkg/ring/ring
Executable file
Binary file not shown.
7
pkg/ring/routes/notification_api.go
Normal file
7
pkg/ring/routes/notification_api.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package routes
|
||||
|
||||
import "github.com/gofiber/fiber/v3"
|
||||
|
||||
func CreatePushSubscription(c fiber.Ctx) error {
|
||||
return c.JSON(fiber.Map{})
|
||||
}
|
||||
435
pkg/ring/services/ring_service.go
Normal file
435
pkg/ring/services/ring_service.go
Normal file
@@ -0,0 +1,435 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"firebase.google.com/go/v4/messaging"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sideshow/apns2"
|
||||
"github.com/sideshow/apns2/payload"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"git.solsynth.dev/goatworks/turbine/pkg/ring/clients"
|
||||
"git.solsynth.dev/goatworks/turbine/pkg/ring/infra"
|
||||
"git.solsynth.dev/goatworks/turbine/pkg/ring/models"
|
||||
"git.solsynth.dev/goatworks/turbine/pkg/ring/websocket"
|
||||
pb "git.solsynth.dev/goatworks/turbine/pkg/shared/proto/gen"
|
||||
)
|
||||
|
||||
// RingServiceServerImpl implements the RingServiceServer and RingHandlerServiceServer interfaces
|
||||
type RingServiceServerImpl struct {
|
||||
pb.UnimplementedRingServiceServer
|
||||
pb.UnimplementedRingHandlerServiceServer
|
||||
WsManager *websocket.Manager // Add WebSocket Manager
|
||||
}
|
||||
|
||||
// sendPushNotification helper function to send notifications via APNs or FCM
|
||||
func sendPushNotification(ctx context.Context, notif *models.Notification, sub *models.NotificationPushSubscription) error {
|
||||
if sub.Provider == models.PushProviderApple {
|
||||
apnsClient := clients.GetAPNsClient()
|
||||
if apnsClient == nil {
|
||||
return fmt.Errorf("APNs client not initialized")
|
||||
}
|
||||
|
||||
apnsPayload := payload.NewPayload().
|
||||
AlertTitle(*notif.Title).
|
||||
AlertBody(*notif.Content)
|
||||
|
||||
if notif.Subtitle != nil {
|
||||
apnsPayload.AlertSubtitle(*notif.Subtitle)
|
||||
}
|
||||
|
||||
if notif.Meta != nil {
|
||||
metaBytes, err := json.Marshal(notif.Meta)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal notification meta for APNs")
|
||||
} else {
|
||||
apnsPayload.Custom("meta", string(metaBytes))
|
||||
}
|
||||
}
|
||||
|
||||
notification := &apns2.Notification{
|
||||
DeviceToken: sub.DeviceToken,
|
||||
Topic: notif.Topic,
|
||||
Payload: apnsPayload,
|
||||
}
|
||||
|
||||
res, err := apnsClient.PushWithContext(ctx, notification)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send APNs push notification: %w", err)
|
||||
}
|
||||
if res.Sent() {
|
||||
log.Info().Msgf("APNs notification sent successfully to device %s (token: %s)", sub.DeviceId, sub.DeviceToken)
|
||||
} else {
|
||||
log.Error().Msgf("APNs notification failed to send to device %s (token: %s) with reason: %s", sub.DeviceId, sub.DeviceToken, res.Reason)
|
||||
if res.StatusCode == 410 { // Expired token
|
||||
log.Warn().Msgf("APNs token for device %s expired. Deleting subscription.", sub.DeviceId)
|
||||
// TODO: Delete the subscription
|
||||
}
|
||||
return fmt.Errorf("APNs push notification failed: %s", res.Reason)
|
||||
}
|
||||
} else if sub.Provider == models.PushProviderGoogle {
|
||||
firebaseClient := clients.GetFirebaseClient()
|
||||
if firebaseClient == nil {
|
||||
return fmt.Errorf("Firebase client not initialized")
|
||||
}
|
||||
|
||||
message := &messaging.Message{
|
||||
Token: sub.DeviceToken,
|
||||
Notification: &messaging.Notification{
|
||||
Title: *notif.Title,
|
||||
Body: *notif.Content,
|
||||
},
|
||||
Data: make(map[string]string),
|
||||
}
|
||||
|
||||
if notif.Meta != nil {
|
||||
for k, v := range notif.Meta {
|
||||
message.Data[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
res, err := firebaseClient.Send(ctx, message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send Firebase push notification: %w", err)
|
||||
}
|
||||
log.Info().Msgf("Firebase notification sent successfully to device %s (token: %s): %s", sub.DeviceId, sub.DeviceToken, res)
|
||||
} else {
|
||||
return fmt.Errorf("unsupported push provider: %v", sub.Provider)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendEmail implements proto.RingServiceServer
|
||||
func (s *RingServiceServerImpl) SendEmail(ctx context.Context, req *pb.SendEmailRequest) (*emptypb.Empty, error) {
|
||||
log.Info().Msgf("Received SendEmail request: %+v", req.GetEmail())
|
||||
|
||||
email := req.GetEmail()
|
||||
if email == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "email message is nil")
|
||||
}
|
||||
|
||||
if email.ToAddress == "" || email.Subject == "" || email.Body == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "to_address, subject, and body cannot be empty")
|
||||
}
|
||||
|
||||
err := clients.SendEmail(email.ToAddress, email.Subject, email.Body)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to send email")
|
||||
return nil, status.Errorf(codes.Internal, "failed to send email: %v", err)
|
||||
}
|
||||
|
||||
log.Info().Msgf("Email sent successfully to %s", email.ToAddress)
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
// PushWebSocketPacket implements proto.RingServiceServer
|
||||
func (s *RingServiceServerImpl) PushWebSocketPacket(ctx context.Context, req *pb.PushWebSocketPacketRequest) (*emptypb.Empty, error) {
|
||||
log.Info().Msgf("Received PushWebSocketPacket request for user %s: %+v", req.GetUserId(), req.GetPacket())
|
||||
|
||||
if s.WsManager == nil {
|
||||
return nil, status.Errorf(codes.Unavailable, "WebSocket manager not initialized")
|
||||
}
|
||||
|
||||
accountID, err := uuid.Parse(req.GetUserId())
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user ID: %v", err)
|
||||
}
|
||||
|
||||
packet := websocket.FromProtoValue(req.GetPacket())
|
||||
s.WsManager.SendPacketToAccount(accountID, packet)
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
// PushWebSocketPacketToUsers implements proto.RingServiceServer
|
||||
func (s *RingServiceServerImpl) PushWebSocketPacketToUsers(ctx context.Context, req *pb.PushWebSocketPacketToUsersRequest) (*emptypb.Empty, error) {
|
||||
log.Info().Msgf("Received PushWebSocketPacketToUsers request for users %v: %+v", req.GetUserIds(), req.GetPacket())
|
||||
|
||||
if s.WsManager == nil {
|
||||
return nil, status.Errorf(codes.Unavailable, "WebSocket manager not initialized")
|
||||
}
|
||||
|
||||
packet := websocket.FromProtoValue(req.GetPacket())
|
||||
for _, userID := range req.GetUserIds() {
|
||||
accountID, err := uuid.Parse(userID)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msgf("Invalid user ID in batch: %s, skipping", userID)
|
||||
continue
|
||||
}
|
||||
s.WsManager.SendPacketToAccount(accountID, packet)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
// PushWebSocketPacketToDevice implements proto.RingServiceServer
|
||||
func (s *RingServiceServerImpl) PushWebSocketPacketToDevice(ctx context.Context, req *pb.PushWebSocketPacketToDeviceRequest) (*emptypb.Empty, error) {
|
||||
log.Info().Msgf("Received PushWebSocketPacketToDevice request for device %s: %+v", req.GetDeviceId(), req.GetPacket())
|
||||
|
||||
if s.WsManager == nil {
|
||||
return nil, status.Errorf(codes.Unavailable, "WebSocket manager not initialized")
|
||||
}
|
||||
|
||||
packet := websocket.FromProtoValue(req.GetPacket())
|
||||
s.WsManager.SendPacketToDevice(req.GetDeviceId(), packet)
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
// PushWebSocketPacketToDevices implements proto.RingServiceServer
|
||||
func (s *RingServiceServerImpl) PushWebSocketPacketToDevices(ctx context.Context, req *pb.PushWebSocketPacketToDevicesRequest) (*emptypb.Empty, error) {
|
||||
log.Info().Msgf("Received PushWebSocketPacketToDevices request for devices %v: %+v", req.GetDeviceIds(), req.GetPacket())
|
||||
|
||||
if s.WsManager == nil {
|
||||
return nil, status.Errorf(codes.Unavailable, "WebSocket manager not initialized")
|
||||
}
|
||||
|
||||
packet := websocket.FromProtoValue(req.GetPacket())
|
||||
for _, deviceID := range req.GetDeviceIds() {
|
||||
s.WsManager.SendPacketToDevice(deviceID, packet)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
// SendPushNotificationToUser implements proto.RingServiceServer
|
||||
func (s *RingServiceServerImpl) SendPushNotificationToUser(ctx context.Context, req *pb.SendPushNotificationToUserRequest) (*emptypb.Empty, error) {
|
||||
log.Info().Msgf("Received SendPushNotificationToUser request for user %s: %+v", req.GetUserId(), req.GetNotification())
|
||||
|
||||
// 1. Parse incoming notification and save to DB
|
||||
accountID, err := uuid.Parse(req.GetUserId())
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid account ID: %v", err)
|
||||
}
|
||||
|
||||
notif := &models.Notification{
|
||||
Topic: req.GetNotification().GetTopic(),
|
||||
Title: &req.GetNotification().Title,
|
||||
Content: &req.GetNotification().Body,
|
||||
AccountId: accountID,
|
||||
}
|
||||
|
||||
if req.GetNotification().Subtitle != "" {
|
||||
subtitle := req.GetNotification().Subtitle
|
||||
notif.Subtitle = &subtitle
|
||||
}
|
||||
|
||||
if req.GetNotification().Meta != nil {
|
||||
var metaMap map[string]any
|
||||
if err := json.Unmarshal(req.GetNotification().GetMeta(), &metaMap); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to unmarshal notification meta, sending without it")
|
||||
} else {
|
||||
notif.Meta = metaMap
|
||||
}
|
||||
}
|
||||
|
||||
if err := infra.Db.WithContext(ctx).Create(notif).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Failed to save notification to database")
|
||||
return nil, status.Errorf(codes.Internal, "failed to save notification: %v", err)
|
||||
}
|
||||
log.Info().Msgf("Notification saved to DB: %s", notif.Id.String())
|
||||
|
||||
// 2. Retrieve push subscriptions for the user
|
||||
var subscriptions []models.NotificationPushSubscription
|
||||
if err := infra.Db.WithContext(ctx).Where("account_id = ?", req.GetUserId()).Find(&subscriptions).Error; err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to retrieve subscriptions for account %s", req.GetUserId())
|
||||
return nil, status.Errorf(codes.Internal, "failed to retrieve subscriptions: %v", err)
|
||||
}
|
||||
|
||||
if len(subscriptions) == 0 {
|
||||
log.Info().Msgf("No active push subscriptions found for user %s", req.GetUserId())
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
// 3. Send notification to each subscription
|
||||
for _, sub := range subscriptions {
|
||||
sub := sub // Create a local copy for the goroutine
|
||||
go func() {
|
||||
if err := sendPushNotification(ctx, notif, &sub); err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to send push notification to device %s", sub.DeviceId)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
// SendPushNotificationToUsers implements proto.RingServiceServer
|
||||
func (s *RingServiceServerImpl) SendPushNotificationToUsers(ctx context.Context, req *pb.SendPushNotificationToUsersRequest) (*emptypb.Empty, error) {
|
||||
log.Info().Msgf("Received SendPushNotificationToUsers request for users %v: %+v", req.GetUserIds(), req.GetNotification())
|
||||
|
||||
// 1. Parse incoming notification (one notification for all users)
|
||||
// We'll create separate DB entries for each user if needed, or link to a single notification.
|
||||
// For simplicity, let's assume we save one notification per user for now.
|
||||
// A better approach for many users might be a single notification record linked to multiple user-notification bridges.
|
||||
|
||||
// Extract common notification details
|
||||
notificationProto := req.GetNotification()
|
||||
metaMap := make(map[string]any)
|
||||
if notificationProto.Meta != nil {
|
||||
if err := json.Unmarshal(notificationProto.GetMeta(), &metaMap); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to unmarshal notification meta for batch, sending without it")
|
||||
}
|
||||
}
|
||||
|
||||
for _, userID := range req.GetUserIds() {
|
||||
accountID, err := uuid.Parse(userID)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msgf("Invalid account ID in batch: %s, skipping", userID)
|
||||
continue
|
||||
}
|
||||
|
||||
notif := &models.Notification{
|
||||
Topic: notificationProto.GetTopic(),
|
||||
Title: ¬ificationProto.Title,
|
||||
Content: ¬ificationProto.Body,
|
||||
Meta: metaMap,
|
||||
AccountId: accountID,
|
||||
}
|
||||
|
||||
if notificationProto.Subtitle != "" {
|
||||
subtitle := notificationProto.Subtitle
|
||||
notif.Subtitle = &subtitle
|
||||
}
|
||||
|
||||
if err := infra.Db.WithContext(ctx).Create(notif).Error; err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to save notification to database for user %s: %v", userID, err)
|
||||
continue
|
||||
}
|
||||
log.Info().Msgf("Notification saved to DB for user %s: %s", userID, notif.Id.String())
|
||||
|
||||
// Retrieve push subscriptions for the current user
|
||||
var subscriptions []models.NotificationPushSubscription
|
||||
if err := infra.Db.WithContext(ctx).Where("account_id = ?", userID).Find(&subscriptions).Error; err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to retrieve subscriptions for account %s in batch: %v", userID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(subscriptions) == 0 {
|
||||
log.Info().Msgf("No active push subscriptions found for user %s in batch", userID)
|
||||
continue
|
||||
}
|
||||
|
||||
// Send notification to each subscription for the current user
|
||||
for _, sub := range subscriptions {
|
||||
sub := sub // Create a local copy for the goroutine
|
||||
go func() {
|
||||
if err := sendPushNotification(ctx, notif, &sub); err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to send push notification to device %s for user %s", sub.DeviceId, userID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
// UnsubscribePushNotifications implements proto.RingServiceServer
|
||||
func (s *RingServiceServerImpl) UnsubscribePushNotifications(ctx context.Context, req *pb.UnsubscribePushNotificationsRequest) (*emptypb.Empty, error) {
|
||||
log.Info().Msgf("Received UnsubscribePushNotifications request for device %s", req.GetDeviceId())
|
||||
|
||||
if req.GetDeviceId() == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "device_id cannot be empty")
|
||||
}
|
||||
|
||||
result := infra.Db.WithContext(ctx).Where("device_id = ?", req.GetDeviceId()).Delete(&models.NotificationPushSubscription{})
|
||||
if result.Error != nil {
|
||||
log.Error().Err(result.Error).Msgf("Failed to unsubscribe device %s", req.GetDeviceId())
|
||||
return nil, status.Errorf(codes.Internal, "failed to unsubscribe: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
log.Info().Msgf("No subscription found for device %s to unsubscribe", req.GetDeviceId())
|
||||
} else {
|
||||
log.Info().Msgf("Successfully unsubscribed device %s", req.GetDeviceId())
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
// GetWebsocketConnectionStatus implements proto.RingServiceServer
|
||||
func (s *RingServiceServerImpl) GetWebsocketConnectionStatus(ctx context.Context, req *pb.GetWebsocketConnectionStatusRequest) (*pb.GetWebsocketConnectionStatusResponse, error) {
|
||||
log.Info().Msgf("Received GetWebsocketConnectionStatus request: %+v", req)
|
||||
|
||||
if s.WsManager == nil {
|
||||
return nil, status.Errorf(codes.Unavailable, "WebSocket manager not initialized")
|
||||
}
|
||||
|
||||
isConnected := false
|
||||
switch id := req.GetId().(type) {
|
||||
case *pb.GetWebsocketConnectionStatusRequest_DeviceId:
|
||||
isConnected = s.WsManager.GetDeviceIsConnected(id.DeviceId)
|
||||
case *pb.GetWebsocketConnectionStatusRequest_UserId:
|
||||
accountID, err := uuid.Parse(id.UserId)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid user ID: %v", err)
|
||||
}
|
||||
isConnected = s.WsManager.GetAccountIsConnected(accountID)
|
||||
default:
|
||||
return nil, status.Errorf(codes.InvalidArgument, "either device_id or user_id must be provided")
|
||||
}
|
||||
|
||||
return &pb.GetWebsocketConnectionStatusResponse{IsConnected: isConnected}, nil
|
||||
}
|
||||
|
||||
// GetWebsocketConnectionStatusBatch implements proto.RingServiceServer
|
||||
func (s *RingServiceServerImpl) GetWebsocketConnectionStatusBatch(ctx context.Context, req *pb.GetWebsocketConnectionStatusBatchRequest) (*pb.GetWebsocketConnectionStatusBatchResponse, error) {
|
||||
log.Info().Msgf("Received GetWebsocketConnectionStatusBatch request: %+v", req)
|
||||
|
||||
if s.WsManager == nil {
|
||||
return nil, status.Errorf(codes.Unavailable, "WebSocket manager not initialized")
|
||||
}
|
||||
|
||||
response := &pb.GetWebsocketConnectionStatusBatchResponse{
|
||||
IsConnected: make(map[string]bool),
|
||||
}
|
||||
|
||||
for _, userID := range req.GetUsersId() {
|
||||
accountID, err := uuid.Parse(userID)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msgf("Invalid user ID in batch: %s, skipping", userID)
|
||||
response.IsConnected[userID] = false // Indicate connection status as false for invalid IDs
|
||||
continue
|
||||
}
|
||||
response.IsConnected[userID] = s.WsManager.GetAccountIsConnected(accountID)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// ReceiveWebSocketPacket implements proto.RingHandlerServiceServer
|
||||
func (s *RingServiceServerImpl) ReceiveWebSocketPacket(ctx context.Context, req *pb.ReceiveWebSocketPacketRequest) (*emptypb.Empty, error) {
|
||||
log.Info().Msgf("Received ReceiveWebSocketPacket request: %+v", req)
|
||||
|
||||
if s.WsManager == nil {
|
||||
return nil, status.Errorf(codes.Unavailable, "WebSocket manager not initialized")
|
||||
}
|
||||
|
||||
packet := websocket.FromProtoValue(req.GetPacket())
|
||||
|
||||
// The C# HandlePacket expects current user and device ID.
|
||||
// For this gRPC endpoint, we can use the account and device_id from the request.
|
||||
if req.GetAccount() == nil || req.GetAccount().GetId() == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "account information missing in request")
|
||||
}
|
||||
if req.GetDeviceId() == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "device_id missing in request")
|
||||
}
|
||||
|
||||
// Assuming the request comes from a trusted source, we can use the provided account.
|
||||
// We don't have a direct *websocket.Conn here, so we can't send error responses back on a direct WebSocket.
|
||||
// Errors will be logged and returned as gRPC errors.
|
||||
err := s.WsManager.HandlePacket(req.GetAccount(), req.GetDeviceId(), packet, nil) // Pass nil for *websocket.Conn
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to handle received WebSocket packet via gRPC")
|
||||
return nil, status.Errorf(codes.Internal, "failed to process packet: %v", err)
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
144
pkg/ring/websocket/controller.go
Normal file
144
pkg/ring/websocket/controller.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/contrib/v3/websocket"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
|
||||
pb "git.solsynth.dev/goatworks/turbine/pkg/shared/proto/gen"
|
||||
)
|
||||
|
||||
// WebSocketController handles the WebSocket endpoint.
|
||||
type WebSocketController struct {
|
||||
Manager *Manager
|
||||
}
|
||||
|
||||
// NewWebSocketController creates a new WebSocketController.
|
||||
func NewWebSocketController(manager *Manager) *WebSocketController {
|
||||
return &WebSocketController{
|
||||
Manager: manager,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleWebSocket is the main handler for the /ws endpoint.
|
||||
func (wc *WebSocketController) HandleWebSocket(c *websocket.Conn) {
|
||||
// Mock Authentication for now based on C# example
|
||||
// In a real scenario, this would involve JWT verification, session lookup, etc.
|
||||
// For demonstration, we'll assume a dummy user and session.
|
||||
// The C# code uses HttpContext.Items to get CurrentUser and CurrentSession.
|
||||
// In Go Fiber, we can pass this through Locals or use middleware to set it up.
|
||||
// For now, let's create a dummy user and session.
|
||||
|
||||
// TODO: Replace with actual authentication logic
|
||||
// For now, assume a dummy account and session
|
||||
// Based on the C# code, CurrentUser and CurrentSession are pb.Account and pb.AuthSession
|
||||
dummyAccount := &pb.Account{
|
||||
Id: uuid.New().String(),
|
||||
Name: "dummy_user",
|
||||
Nick: "Dummy User",
|
||||
}
|
||||
dummySession := &pb.AuthSession{
|
||||
ClientId: lo.ToPtr(uuid.New().String()), // This is used as deviceId if not present
|
||||
}
|
||||
|
||||
// Device ID handling
|
||||
deviceAlt := c.Query("deviceAlt")
|
||||
if deviceAlt != "" {
|
||||
allowedDeviceAlternative := []string{"watch"} // Hardcoded for now
|
||||
found := false
|
||||
for _, alt := range allowedDeviceAlternative {
|
||||
if deviceAlt == alt {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
log.Warn().Msgf("Unsupported device alternative: %s", deviceAlt)
|
||||
bytes, _ := newErrorPacket("Unsupported device alternative: " + deviceAlt).ToBytes()
|
||||
c.WriteMessage(websocket.BinaryMessage, bytes)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
accountID := uuid.MustParse(dummyAccount.Id)
|
||||
deviceIDStr := ""
|
||||
if dummySession.ClientId == nil {
|
||||
deviceIDStr = uuid.New().String()
|
||||
} else {
|
||||
deviceIDStr = *dummySession.ClientId
|
||||
}
|
||||
if deviceAlt != "" {
|
||||
deviceIDStr = fmt.Sprintf("%s+%s", deviceIDStr, deviceAlt)
|
||||
}
|
||||
|
||||
// Setup connection context for cancellation
|
||||
cancel := func() {} // Placeholder
|
||||
connectionKey := ConnectionKey{AccountID: accountID, DeviceID: deviceIDStr}
|
||||
|
||||
// Add connection to manager
|
||||
if !wc.Manager.TryAdd(connectionKey, c, cancel) {
|
||||
bytes, _ := newErrorPacket("Too many connections from the same device and account.").ToBytes()
|
||||
c.WriteMessage(websocket.BinaryMessage, bytes)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Connection established with user @%s#%s and device #%s", dummyAccount.Name, dummyAccount.Id, deviceIDStr)
|
||||
|
||||
// Publish WebSocket connected event
|
||||
wc.Manager.PublishWebSocketConnectedEvent(accountID, deviceIDStr, false) // isOffline is false on connection
|
||||
|
||||
defer func() {
|
||||
wc.Manager.Disconnect(connectionKey, "Client disconnected.")
|
||||
// Publish WebSocket disconnected event
|
||||
isOffline := !wc.Manager.GetAccountIsConnected(accountID) // Check if account is completely offline
|
||||
wc.Manager.PublishWebSocketDisconnectedEvent(accountID, deviceIDStr, isOffline)
|
||||
log.Debug().Msgf("Connection disconnected with user @%s#%s and device #%s", dummyAccount.Name, dummyAccount.Id, deviceIDStr)
|
||||
}()
|
||||
|
||||
// Main event loop
|
||||
for {
|
||||
mt, msg, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("WebSocket read error")
|
||||
break
|
||||
}
|
||||
|
||||
if mt == websocket.CloseMessage {
|
||||
log.Info().Msg("Received close message from client")
|
||||
break
|
||||
}
|
||||
|
||||
if mt != websocket.BinaryMessage {
|
||||
log.Warn().Msgf("Received non-binary message type: %d", mt)
|
||||
continue
|
||||
}
|
||||
|
||||
packet, err := FromBytes(msg)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to deserialize WebSocket packet")
|
||||
bytes, _ := newErrorPacket("Failed to deserialize packet").ToBytes()
|
||||
c.WriteMessage(websocket.BinaryMessage, bytes)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := wc.Manager.HandlePacket(dummyAccount, deviceIDStr, packet, c); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to handle incoming WebSocket packet")
|
||||
bytes, _ := newErrorPacket("Failed to process packet").ToBytes()
|
||||
c.WriteMessage(websocket.BinaryMessage, bytes)
|
||||
// Depending on error, might want to close connection
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newErrorPacket creates a new WebSocketPacket with an error type.
|
||||
func newErrorPacket(message string) *WebSocketPacket {
|
||||
return &WebSocketPacket{
|
||||
Type: WebSocketPacketTypeError,
|
||||
ErrorMessage: message,
|
||||
}
|
||||
}
|
||||
292
pkg/ring/websocket/manager.go
Normal file
292
pkg/ring/websocket/manager.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/gofiber/contrib/v3/websocket"
|
||||
"github.com/google/uuid"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/protobuf/proto" // Import for proto.Marshal
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
|
||||
pb "git.solsynth.dev/goatworks/turbine/pkg/shared/proto/gen"
|
||||
)
|
||||
|
||||
// ConnectionKey represents the unique identifier for a WebSocket connection.
|
||||
type ConnectionKey struct {
|
||||
AccountID uuid.UUID
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
// ConnectionState holds the WebSocket connection and its cancellation context.
|
||||
type ConnectionState struct {
|
||||
Conn *websocket.Conn
|
||||
Cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// Manager manages active WebSocket connections.
|
||||
type Manager struct {
|
||||
connections sync.Map // Map[ConnectionKey]*ConnectionState
|
||||
natsClient *nats.Conn
|
||||
}
|
||||
|
||||
// NewManager creates a new WebSocket Manager.
|
||||
func NewManager(natsClient *nats.Conn) *Manager {
|
||||
return &Manager{
|
||||
natsClient: natsClient,
|
||||
}
|
||||
}
|
||||
|
||||
// TryAdd attempts to add a new connection. If a connection with the same key exists,
|
||||
// it disconnects the old one and then adds the new one.
|
||||
func (m *Manager) TryAdd(key ConnectionKey, conn *websocket.Conn, cancel context.CancelFunc) bool {
|
||||
// Disconnect existing connection with the same identifier, if any
|
||||
if _, loaded := m.connections.Load(key); loaded {
|
||||
log.Warn().Msgf("Duplicate connection detected for %s:%s. Disconnecting old one.", key.AccountID, key.DeviceID)
|
||||
m.Disconnect(key, "Just connected somewhere else with the same identifier.")
|
||||
}
|
||||
|
||||
m.connections.Store(key, &ConnectionState{Conn: conn, Cancel: cancel})
|
||||
log.Info().Msgf("Connection established for user %s and device %s", key.AccountID, key.DeviceID)
|
||||
return true
|
||||
}
|
||||
|
||||
// Disconnect removes a connection and closes it.
|
||||
func (m *Manager) Disconnect(key ConnectionKey, reason string) {
|
||||
if stateAny, loaded := m.connections.LoadAndDelete(key); loaded {
|
||||
state := stateAny.(ConnectionState)
|
||||
|
||||
// Cancel the context to stop any goroutines associated with this connection
|
||||
if state.Cancel != nil {
|
||||
state.Cancel()
|
||||
}
|
||||
|
||||
// Close the WebSocket connection
|
||||
err := state.Conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, reason))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Error sending close message to WebSocket for %s:%s", key.AccountID, key.DeviceID)
|
||||
}
|
||||
err = state.Conn.Close() // Ensure the connection is closed
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Error closing WebSocket for %s:%s", key.AccountID, key.DeviceID)
|
||||
}
|
||||
|
||||
log.Info().Msgf("Connection disconnected for user %s and device %s. Reason: %s", key.AccountID, key.DeviceID, reason)
|
||||
}
|
||||
}
|
||||
|
||||
// GetDeviceIsConnected checks if any connection exists for a given device ID.
|
||||
func (m *Manager) GetDeviceIsConnected(deviceID string) bool {
|
||||
var isConnected bool
|
||||
m.connections.Range(func(k, v interface{}) bool {
|
||||
connKey := k.(ConnectionKey)
|
||||
if connKey.DeviceID == deviceID {
|
||||
isConnected = true
|
||||
return false // Stop iteration
|
||||
}
|
||||
return true
|
||||
})
|
||||
return isConnected
|
||||
}
|
||||
|
||||
// GetAccountIsConnected checks if any connection exists for a given account ID.
|
||||
func (m *Manager) GetAccountIsConnected(accountID uuid.UUID) bool {
|
||||
var isConnected bool
|
||||
m.connections.Range(func(k, v interface{}) bool {
|
||||
connKey := k.(ConnectionKey)
|
||||
if connKey.AccountID == accountID {
|
||||
isConnected = true
|
||||
return false // Stop iteration
|
||||
}
|
||||
return true
|
||||
})
|
||||
return isConnected
|
||||
}
|
||||
|
||||
// SendPacketToAccount sends a WebSocketPacket to all connections for a given account ID.
|
||||
func (m *Manager) SendPacketToAccount(accountID uuid.UUID, packet *WebSocketPacket) {
|
||||
packetBytes, err := packet.ToBytes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to marshal packet for account %s", accountID)
|
||||
return
|
||||
}
|
||||
|
||||
m.connections.Range(func(k, v interface{}) bool {
|
||||
connKey := k.(ConnectionKey)
|
||||
if connKey.AccountID == accountID {
|
||||
state := v.(*ConnectionState)
|
||||
if state.Conn != nil {
|
||||
err := state.Conn.WriteMessage(websocket.BinaryMessage, packetBytes)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to send packet to account %s, device %s", accountID, connKey.DeviceID)
|
||||
// Optionally, disconnect this problematic connection
|
||||
// m.Disconnect(connKey, "Failed to send packet")
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// SendPacketToDevice sends a WebSocketPacket to all connections for a given device ID.
|
||||
func (m *Manager) SendPacketToDevice(deviceID string, packet *WebSocketPacket) {
|
||||
packetBytes, err := packet.ToBytes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to marshal packet for device %s", deviceID)
|
||||
return
|
||||
}
|
||||
|
||||
m.connections.Range(func(k, v interface{}) bool {
|
||||
connKey := k.(ConnectionKey)
|
||||
if connKey.DeviceID == deviceID {
|
||||
state := v.(*ConnectionState)
|
||||
if state.Conn != nil {
|
||||
err := state.Conn.WriteMessage(websocket.BinaryMessage, packetBytes)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to send packet to device %s, account %s", deviceID, connKey.AccountID)
|
||||
// Optionally, disconnect this problematic connection
|
||||
// m.Disconnect(connKey, "Failed to send packet")
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// PublishWebSocketConnectedEvent publishes a WebSocketConnectedEvent to NATS.
|
||||
func (m *Manager) PublishWebSocketConnectedEvent(accountID uuid.UUID, deviceID string, isOffline bool) {
|
||||
if m.natsClient == nil {
|
||||
log.Warn().Msg("NATS client not initialized. Cannot publish WebSocketConnectedEvent.")
|
||||
return
|
||||
}
|
||||
|
||||
event := &pb.WebSocketConnectedEvent{
|
||||
AccountId: wrapperspb.String(accountID.String()),
|
||||
DeviceId: deviceID,
|
||||
IsOffline: isOffline,
|
||||
}
|
||||
eventBytes, err := proto.Marshal(event)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal WebSocketConnectedEvent")
|
||||
return
|
||||
}
|
||||
|
||||
// Assuming WebSocketConnectedEvent.Type is a constant or can be derived
|
||||
// For now, let's use a hardcoded subject. This needs to be consistent with C# code.
|
||||
// C# uses `WebSocketConnectedEvent.Type`
|
||||
// TODO: Define NATS subjects in a centralized place or derive from proto events.
|
||||
natsSubject := "turbine.websocket.connected"
|
||||
err = m.natsClient.Publish(natsSubject, eventBytes)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to publish WebSocketConnectedEvent to NATS")
|
||||
} else {
|
||||
log.Info().Msgf("Published WebSocketConnectedEvent for %s:%s to NATS subject %s", accountID, deviceID, natsSubject)
|
||||
}
|
||||
}
|
||||
|
||||
// PublishWebSocketDisconnectedEvent publishes a WebSocketDisconnectedEvent to NATS.
|
||||
func (m *Manager) PublishWebSocketDisconnectedEvent(accountID uuid.UUID, deviceID string, isOffline bool) {
|
||||
if m.natsClient == nil {
|
||||
log.Warn().Msg("NATS client not initialized. Cannot publish WebSocketDisconnectedEvent.")
|
||||
return
|
||||
}
|
||||
|
||||
event := &pb.WebSocketDisconnectedEvent{
|
||||
AccountId: wrapperspb.String(accountID.String()),
|
||||
DeviceId: deviceID,
|
||||
IsOffline: isOffline,
|
||||
}
|
||||
eventBytes, err := proto.Marshal(event)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal WebSocketDisconnectedEvent")
|
||||
return
|
||||
}
|
||||
|
||||
// Assuming WebSocketDisconnectedEvent.Type is a constant or can be derived
|
||||
natsSubject := "turbine.websocket.disconnected"
|
||||
err = m.natsClient.Publish(natsSubject, eventBytes)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to publish WebSocketDisconnectedEvent to NATS")
|
||||
} else {
|
||||
log.Info().Msgf("Published WebSocketDisconnectedEvent for %s:%s to NATS subject %s", accountID, deviceID, natsSubject)
|
||||
}
|
||||
}
|
||||
|
||||
// HandlePacket processes incoming WebSocketPacket.
|
||||
func (m *Manager) HandlePacket(currentUser *pb.Account, deviceID string, packet *WebSocketPacket, conn *websocket.Conn) error {
|
||||
switch packet.Type {
|
||||
case WebSocketPacketTypePing:
|
||||
pongPacket := &WebSocketPacket{Type: WebSocketPacketTypePong}
|
||||
pongBytes, err := pongPacket.ToBytes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal pong packet")
|
||||
return err
|
||||
}
|
||||
return conn.WriteMessage(websocket.BinaryMessage, pongBytes)
|
||||
case WebSocketPacketTypeError:
|
||||
log.Error().Msgf("Received error packet from device %s: %s", deviceID, packet.ErrorMessage)
|
||||
return nil // Or handle error appropriately
|
||||
default:
|
||||
if packet.Endpoint != "" {
|
||||
return m.forwardPacketToNATS(currentUser, deviceID, packet)
|
||||
} else {
|
||||
errorPacket := &WebSocketPacket{
|
||||
Type: WebSocketPacketTypeError,
|
||||
ErrorMessage: fmt.Sprintf("Unprocessable packet: %s", packet.Type),
|
||||
}
|
||||
errorBytes, err := errorPacket.ToBytes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal error packet")
|
||||
return err
|
||||
}
|
||||
return conn.WriteMessage(websocket.BinaryMessage, errorBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) forwardPacketToNATS(currentUser *pb.Account, deviceID string, packet *WebSocketPacket) error {
|
||||
if m.natsClient == nil {
|
||||
log.Warn().Msg("NATS client not initialized. Cannot forward packets to NATS.")
|
||||
return fmt.Errorf("NATS client not initialized")
|
||||
}
|
||||
|
||||
packetBytes, err := packet.ToBytes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal WebSocketPacket for NATS forwarding")
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert currentUser.Id to string for the proto message
|
||||
var accountIDStr string
|
||||
if currentUser != nil && currentUser.Id != "" {
|
||||
accountIDStr = currentUser.Id
|
||||
} else {
|
||||
log.Warn().Msg("CurrentUser or CurrentUser.Id is nil/empty for NATS forwarding")
|
||||
return fmt.Errorf("current user ID is missing")
|
||||
}
|
||||
|
||||
event := &pb.WebSocketPacketEvent{
|
||||
AccountId: wrapperspb.String(accountIDStr),
|
||||
DeviceId: deviceID,
|
||||
PacketBytes: packetBytes,
|
||||
}
|
||||
eventBytes, err := proto.Marshal(event)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal WebSocketPacketEvent for NATS")
|
||||
return err
|
||||
}
|
||||
|
||||
// C# uses WebSocketPacketEvent.SubjectPrefix + endpoint
|
||||
// TODO: Centralize NATS subject definitions
|
||||
natsSubject := fmt.Sprintf("turbine.websocket.packet.%s", packet.Endpoint)
|
||||
err = m.natsClient.Publish(natsSubject, eventBytes)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to publish WebSocketPacketEvent to NATS")
|
||||
return err
|
||||
}
|
||||
log.Info().Msgf("Forwarded packet to NATS subject %s from device %s", natsSubject, deviceID)
|
||||
return nil
|
||||
}
|
||||
82
pkg/ring/websocket/models.go
Normal file
82
pkg/ring/websocket/models.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
pb "git.solsynth.dev/goatworks/turbine/pkg/shared/proto/gen"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
)
|
||||
|
||||
// WebSocketPacket represents a WebSocket message packet.
|
||||
type WebSocketPacket struct {
|
||||
Type string `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"` // Use json.RawMessage to delay deserialization
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
}
|
||||
|
||||
// ToBytes serializes the WebSocketPacket to a byte array for sending over WebSocket.
|
||||
func (w *WebSocketPacket) ToBytes() ([]byte, error) {
|
||||
return json.Marshal(w)
|
||||
}
|
||||
|
||||
// FromBytes deserializes a byte array into a WebSocketPacket.
|
||||
func FromBytes(bytes []byte) (*WebSocketPacket, error) {
|
||||
var packet WebSocketPacket
|
||||
err := json.Unmarshal(bytes, &packet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to deserialize WebSocketPacket: %w", err)
|
||||
}
|
||||
return &packet, nil
|
||||
}
|
||||
|
||||
// GetData deserializes the Data property to the specified type T.
|
||||
func (w *WebSocketPacket) GetData(v interface{}) error {
|
||||
if w.Data == nil {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(w.Data, v)
|
||||
}
|
||||
|
||||
// ToProtoValue converts the WebSocketPacket to its protobuf equivalent.
|
||||
func (w *WebSocketPacket) ToProtoValue() *pb.WebSocketPacket {
|
||||
var dataBytes []byte
|
||||
if w.Data != nil {
|
||||
dataBytes = w.Data
|
||||
}
|
||||
var errorMessage string
|
||||
if w.ErrorMessage != "" {
|
||||
errorMessage = w.ErrorMessage
|
||||
}
|
||||
return &pb.WebSocketPacket{
|
||||
Type: w.Type,
|
||||
Data: dataBytes,
|
||||
ErrorMessage: wrapperspb.String(errorMessage),
|
||||
}
|
||||
}
|
||||
|
||||
// FromProtoValue converts a protobuf WebSocketPacket to its Go struct equivalent.
|
||||
func FromProtoValue(packet *pb.WebSocketPacket) *WebSocketPacket {
|
||||
var data json.RawMessage
|
||||
if packet.Data != nil {
|
||||
data = json.RawMessage(packet.Data)
|
||||
}
|
||||
var errorMessage string
|
||||
if packet.ErrorMessage != nil {
|
||||
errorMessage = packet.ErrorMessage.GetValue()
|
||||
}
|
||||
|
||||
return &WebSocketPacket{
|
||||
Type: packet.Type,
|
||||
Data: data,
|
||||
ErrorMessage: errorMessage,
|
||||
}
|
||||
}
|
||||
|
||||
// WebSocketPacketType constants from C# example
|
||||
const (
|
||||
WebSocketPacketTypePing = "ping"
|
||||
WebSocketPacketTypePong = "pong"
|
||||
WebSocketPacketTypeError = "error"
|
||||
)
|
||||
Reference in New Issue
Block a user