🎉 Initial Commit of Ring
This commit is contained in:
292
pkg/ring/websocket/manager.go
Normal file
292
pkg/ring/websocket/manager.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/gofiber/contrib/v3/websocket"
|
||||
"github.com/google/uuid"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/protobuf/proto" // Import for proto.Marshal
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
|
||||
pb "git.solsynth.dev/goatworks/turbine/pkg/shared/proto/gen"
|
||||
)
|
||||
|
||||
// ConnectionKey represents the unique identifier for a WebSocket connection.
|
||||
type ConnectionKey struct {
|
||||
AccountID uuid.UUID
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
// ConnectionState holds the WebSocket connection and its cancellation context.
|
||||
type ConnectionState struct {
|
||||
Conn *websocket.Conn
|
||||
Cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// Manager manages active WebSocket connections.
|
||||
type Manager struct {
|
||||
connections sync.Map // Map[ConnectionKey]*ConnectionState
|
||||
natsClient *nats.Conn
|
||||
}
|
||||
|
||||
// NewManager creates a new WebSocket Manager.
|
||||
func NewManager(natsClient *nats.Conn) *Manager {
|
||||
return &Manager{
|
||||
natsClient: natsClient,
|
||||
}
|
||||
}
|
||||
|
||||
// TryAdd attempts to add a new connection. If a connection with the same key exists,
|
||||
// it disconnects the old one and then adds the new one.
|
||||
func (m *Manager) TryAdd(key ConnectionKey, conn *websocket.Conn, cancel context.CancelFunc) bool {
|
||||
// Disconnect existing connection with the same identifier, if any
|
||||
if _, loaded := m.connections.Load(key); loaded {
|
||||
log.Warn().Msgf("Duplicate connection detected for %s:%s. Disconnecting old one.", key.AccountID, key.DeviceID)
|
||||
m.Disconnect(key, "Just connected somewhere else with the same identifier.")
|
||||
}
|
||||
|
||||
m.connections.Store(key, &ConnectionState{Conn: conn, Cancel: cancel})
|
||||
log.Info().Msgf("Connection established for user %s and device %s", key.AccountID, key.DeviceID)
|
||||
return true
|
||||
}
|
||||
|
||||
// Disconnect removes a connection and closes it.
|
||||
func (m *Manager) Disconnect(key ConnectionKey, reason string) {
|
||||
if stateAny, loaded := m.connections.LoadAndDelete(key); loaded {
|
||||
state := stateAny.(ConnectionState)
|
||||
|
||||
// Cancel the context to stop any goroutines associated with this connection
|
||||
if state.Cancel != nil {
|
||||
state.Cancel()
|
||||
}
|
||||
|
||||
// Close the WebSocket connection
|
||||
err := state.Conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, reason))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Error sending close message to WebSocket for %s:%s", key.AccountID, key.DeviceID)
|
||||
}
|
||||
err = state.Conn.Close() // Ensure the connection is closed
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Error closing WebSocket for %s:%s", key.AccountID, key.DeviceID)
|
||||
}
|
||||
|
||||
log.Info().Msgf("Connection disconnected for user %s and device %s. Reason: %s", key.AccountID, key.DeviceID, reason)
|
||||
}
|
||||
}
|
||||
|
||||
// GetDeviceIsConnected checks if any connection exists for a given device ID.
|
||||
func (m *Manager) GetDeviceIsConnected(deviceID string) bool {
|
||||
var isConnected bool
|
||||
m.connections.Range(func(k, v interface{}) bool {
|
||||
connKey := k.(ConnectionKey)
|
||||
if connKey.DeviceID == deviceID {
|
||||
isConnected = true
|
||||
return false // Stop iteration
|
||||
}
|
||||
return true
|
||||
})
|
||||
return isConnected
|
||||
}
|
||||
|
||||
// GetAccountIsConnected checks if any connection exists for a given account ID.
|
||||
func (m *Manager) GetAccountIsConnected(accountID uuid.UUID) bool {
|
||||
var isConnected bool
|
||||
m.connections.Range(func(k, v interface{}) bool {
|
||||
connKey := k.(ConnectionKey)
|
||||
if connKey.AccountID == accountID {
|
||||
isConnected = true
|
||||
return false // Stop iteration
|
||||
}
|
||||
return true
|
||||
})
|
||||
return isConnected
|
||||
}
|
||||
|
||||
// SendPacketToAccount sends a WebSocketPacket to all connections for a given account ID.
|
||||
func (m *Manager) SendPacketToAccount(accountID uuid.UUID, packet *WebSocketPacket) {
|
||||
packetBytes, err := packet.ToBytes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to marshal packet for account %s", accountID)
|
||||
return
|
||||
}
|
||||
|
||||
m.connections.Range(func(k, v interface{}) bool {
|
||||
connKey := k.(ConnectionKey)
|
||||
if connKey.AccountID == accountID {
|
||||
state := v.(*ConnectionState)
|
||||
if state.Conn != nil {
|
||||
err := state.Conn.WriteMessage(websocket.BinaryMessage, packetBytes)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to send packet to account %s, device %s", accountID, connKey.DeviceID)
|
||||
// Optionally, disconnect this problematic connection
|
||||
// m.Disconnect(connKey, "Failed to send packet")
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// SendPacketToDevice sends a WebSocketPacket to all connections for a given device ID.
|
||||
func (m *Manager) SendPacketToDevice(deviceID string, packet *WebSocketPacket) {
|
||||
packetBytes, err := packet.ToBytes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to marshal packet for device %s", deviceID)
|
||||
return
|
||||
}
|
||||
|
||||
m.connections.Range(func(k, v interface{}) bool {
|
||||
connKey := k.(ConnectionKey)
|
||||
if connKey.DeviceID == deviceID {
|
||||
state := v.(*ConnectionState)
|
||||
if state.Conn != nil {
|
||||
err := state.Conn.WriteMessage(websocket.BinaryMessage, packetBytes)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to send packet to device %s, account %s", deviceID, connKey.AccountID)
|
||||
// Optionally, disconnect this problematic connection
|
||||
// m.Disconnect(connKey, "Failed to send packet")
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// PublishWebSocketConnectedEvent publishes a WebSocketConnectedEvent to NATS.
|
||||
func (m *Manager) PublishWebSocketConnectedEvent(accountID uuid.UUID, deviceID string, isOffline bool) {
|
||||
if m.natsClient == nil {
|
||||
log.Warn().Msg("NATS client not initialized. Cannot publish WebSocketConnectedEvent.")
|
||||
return
|
||||
}
|
||||
|
||||
event := &pb.WebSocketConnectedEvent{
|
||||
AccountId: wrapperspb.String(accountID.String()),
|
||||
DeviceId: deviceID,
|
||||
IsOffline: isOffline,
|
||||
}
|
||||
eventBytes, err := proto.Marshal(event)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal WebSocketConnectedEvent")
|
||||
return
|
||||
}
|
||||
|
||||
// Assuming WebSocketConnectedEvent.Type is a constant or can be derived
|
||||
// For now, let's use a hardcoded subject. This needs to be consistent with C# code.
|
||||
// C# uses `WebSocketConnectedEvent.Type`
|
||||
// TODO: Define NATS subjects in a centralized place or derive from proto events.
|
||||
natsSubject := "turbine.websocket.connected"
|
||||
err = m.natsClient.Publish(natsSubject, eventBytes)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to publish WebSocketConnectedEvent to NATS")
|
||||
} else {
|
||||
log.Info().Msgf("Published WebSocketConnectedEvent for %s:%s to NATS subject %s", accountID, deviceID, natsSubject)
|
||||
}
|
||||
}
|
||||
|
||||
// PublishWebSocketDisconnectedEvent publishes a WebSocketDisconnectedEvent to NATS.
|
||||
func (m *Manager) PublishWebSocketDisconnectedEvent(accountID uuid.UUID, deviceID string, isOffline bool) {
|
||||
if m.natsClient == nil {
|
||||
log.Warn().Msg("NATS client not initialized. Cannot publish WebSocketDisconnectedEvent.")
|
||||
return
|
||||
}
|
||||
|
||||
event := &pb.WebSocketDisconnectedEvent{
|
||||
AccountId: wrapperspb.String(accountID.String()),
|
||||
DeviceId: deviceID,
|
||||
IsOffline: isOffline,
|
||||
}
|
||||
eventBytes, err := proto.Marshal(event)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal WebSocketDisconnectedEvent")
|
||||
return
|
||||
}
|
||||
|
||||
// Assuming WebSocketDisconnectedEvent.Type is a constant or can be derived
|
||||
natsSubject := "turbine.websocket.disconnected"
|
||||
err = m.natsClient.Publish(natsSubject, eventBytes)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to publish WebSocketDisconnectedEvent to NATS")
|
||||
} else {
|
||||
log.Info().Msgf("Published WebSocketDisconnectedEvent for %s:%s to NATS subject %s", accountID, deviceID, natsSubject)
|
||||
}
|
||||
}
|
||||
|
||||
// HandlePacket processes incoming WebSocketPacket.
|
||||
func (m *Manager) HandlePacket(currentUser *pb.Account, deviceID string, packet *WebSocketPacket, conn *websocket.Conn) error {
|
||||
switch packet.Type {
|
||||
case WebSocketPacketTypePing:
|
||||
pongPacket := &WebSocketPacket{Type: WebSocketPacketTypePong}
|
||||
pongBytes, err := pongPacket.ToBytes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal pong packet")
|
||||
return err
|
||||
}
|
||||
return conn.WriteMessage(websocket.BinaryMessage, pongBytes)
|
||||
case WebSocketPacketTypeError:
|
||||
log.Error().Msgf("Received error packet from device %s: %s", deviceID, packet.ErrorMessage)
|
||||
return nil // Or handle error appropriately
|
||||
default:
|
||||
if packet.Endpoint != "" {
|
||||
return m.forwardPacketToNATS(currentUser, deviceID, packet)
|
||||
} else {
|
||||
errorPacket := &WebSocketPacket{
|
||||
Type: WebSocketPacketTypeError,
|
||||
ErrorMessage: fmt.Sprintf("Unprocessable packet: %s", packet.Type),
|
||||
}
|
||||
errorBytes, err := errorPacket.ToBytes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal error packet")
|
||||
return err
|
||||
}
|
||||
return conn.WriteMessage(websocket.BinaryMessage, errorBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) forwardPacketToNATS(currentUser *pb.Account, deviceID string, packet *WebSocketPacket) error {
|
||||
if m.natsClient == nil {
|
||||
log.Warn().Msg("NATS client not initialized. Cannot forward packets to NATS.")
|
||||
return fmt.Errorf("NATS client not initialized")
|
||||
}
|
||||
|
||||
packetBytes, err := packet.ToBytes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal WebSocketPacket for NATS forwarding")
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert currentUser.Id to string for the proto message
|
||||
var accountIDStr string
|
||||
if currentUser != nil && currentUser.Id != "" {
|
||||
accountIDStr = currentUser.Id
|
||||
} else {
|
||||
log.Warn().Msg("CurrentUser or CurrentUser.Id is nil/empty for NATS forwarding")
|
||||
return fmt.Errorf("current user ID is missing")
|
||||
}
|
||||
|
||||
event := &pb.WebSocketPacketEvent{
|
||||
AccountId: wrapperspb.String(accountIDStr),
|
||||
DeviceId: deviceID,
|
||||
PacketBytes: packetBytes,
|
||||
}
|
||||
eventBytes, err := proto.Marshal(event)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal WebSocketPacketEvent for NATS")
|
||||
return err
|
||||
}
|
||||
|
||||
// C# uses WebSocketPacketEvent.SubjectPrefix + endpoint
|
||||
// TODO: Centralize NATS subject definitions
|
||||
natsSubject := fmt.Sprintf("turbine.websocket.packet.%s", packet.Endpoint)
|
||||
err = m.natsClient.Publish(natsSubject, eventBytes)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to publish WebSocketPacketEvent to NATS")
|
||||
return err
|
||||
}
|
||||
log.Info().Msgf("Forwarded packet to NATS subject %s from device %s", natsSubject, deviceID)
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user