package websocket import ( "context" "fmt" "sync" "github.com/gofiber/contrib/v3/websocket" "github.com/google/uuid" "github.com/nats-io/nats.go" "github.com/rs/zerolog/log" "google.golang.org/protobuf/proto" // Import for proto.Marshal "google.golang.org/protobuf/types/known/wrapperspb" pb "git.solsynth.dev/goatworks/turbine/pkg/shared/proto/gen" ) // ConnectionKey represents the unique identifier for a WebSocket connection. type ConnectionKey struct { AccountID uuid.UUID DeviceID string } // ConnectionState holds the WebSocket connection and its cancellation context. type ConnectionState struct { Conn *websocket.Conn Cancel context.CancelFunc } // Manager manages active WebSocket connections. type Manager struct { connections sync.Map // Map[ConnectionKey]*ConnectionState natsClient *nats.Conn } // NewManager creates a new WebSocket Manager. func NewManager(natsClient *nats.Conn) *Manager { return &Manager{ natsClient: natsClient, } } // TryAdd attempts to add a new connection. If a connection with the same key exists, // it disconnects the old one and then adds the new one. func (m *Manager) TryAdd(key ConnectionKey, conn *websocket.Conn, cancel context.CancelFunc) bool { // Disconnect existing connection with the same identifier, if any if _, loaded := m.connections.Load(key); loaded { log.Warn().Msgf("Duplicate connection detected for %s:%s. Disconnecting old one.", key.AccountID, key.DeviceID) m.Disconnect(key, "Just connected somewhere else with the same identifier.") } m.connections.Store(key, &ConnectionState{Conn: conn, Cancel: cancel}) log.Info().Msgf("Connection established for user %s and device %s", key.AccountID, key.DeviceID) return true } // Disconnect removes a connection and closes it. func (m *Manager) Disconnect(key ConnectionKey, reason string) { if stateAny, loaded := m.connections.LoadAndDelete(key); loaded { state := stateAny.(ConnectionState) // Cancel the context to stop any goroutines associated with this connection if state.Cancel != nil { state.Cancel() } // Close the WebSocket connection err := state.Conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, reason)) if err != nil { log.Error().Err(err).Msgf("Error sending close message to WebSocket for %s:%s", key.AccountID, key.DeviceID) } err = state.Conn.Close() // Ensure the connection is closed if err != nil { log.Error().Err(err).Msgf("Error closing WebSocket for %s:%s", key.AccountID, key.DeviceID) } log.Info().Msgf("Connection disconnected for user %s and device %s. Reason: %s", key.AccountID, key.DeviceID, reason) } } // GetDeviceIsConnected checks if any connection exists for a given device ID. func (m *Manager) GetDeviceIsConnected(deviceID string) bool { var isConnected bool m.connections.Range(func(k, v interface{}) bool { connKey := k.(ConnectionKey) if connKey.DeviceID == deviceID { isConnected = true return false // Stop iteration } return true }) return isConnected } // GetAccountIsConnected checks if any connection exists for a given account ID. func (m *Manager) GetAccountIsConnected(accountID uuid.UUID) bool { var isConnected bool m.connections.Range(func(k, v interface{}) bool { connKey := k.(ConnectionKey) if connKey.AccountID == accountID { isConnected = true return false // Stop iteration } return true }) return isConnected } // SendPacketToAccount sends a WebSocketPacket to all connections for a given account ID. func (m *Manager) SendPacketToAccount(accountID uuid.UUID, packet *WebSocketPacket) { packetBytes, err := packet.ToBytes() if err != nil { log.Error().Err(err).Msgf("Failed to marshal packet for account %s", accountID) return } m.connections.Range(func(k, v interface{}) bool { connKey := k.(ConnectionKey) if connKey.AccountID == accountID { state := v.(*ConnectionState) if state.Conn != nil { err := state.Conn.WriteMessage(websocket.BinaryMessage, packetBytes) if err != nil { log.Error().Err(err).Msgf("Failed to send packet to account %s, device %s", accountID, connKey.DeviceID) // Optionally, disconnect this problematic connection // m.Disconnect(connKey, "Failed to send packet") } } } return true }) } // SendPacketToDevice sends a WebSocketPacket to all connections for a given device ID. func (m *Manager) SendPacketToDevice(deviceID string, packet *WebSocketPacket) { packetBytes, err := packet.ToBytes() if err != nil { log.Error().Err(err).Msgf("Failed to marshal packet for device %s", deviceID) return } m.connections.Range(func(k, v interface{}) bool { connKey := k.(ConnectionKey) if connKey.DeviceID == deviceID { state := v.(*ConnectionState) if state.Conn != nil { err := state.Conn.WriteMessage(websocket.BinaryMessage, packetBytes) if err != nil { log.Error().Err(err).Msgf("Failed to send packet to device %s, account %s", deviceID, connKey.AccountID) // Optionally, disconnect this problematic connection // m.Disconnect(connKey, "Failed to send packet") } } } return true }) } // PublishWebSocketConnectedEvent publishes a WebSocketConnectedEvent to NATS. func (m *Manager) PublishWebSocketConnectedEvent(accountID uuid.UUID, deviceID string, isOffline bool) { if m.natsClient == nil { log.Warn().Msg("NATS client not initialized. Cannot publish WebSocketConnectedEvent.") return } event := &pb.WebSocketConnectedEvent{ AccountId: wrapperspb.String(accountID.String()), DeviceId: deviceID, IsOffline: isOffline, } eventBytes, err := proto.Marshal(event) if err != nil { log.Error().Err(err).Msg("Failed to marshal WebSocketConnectedEvent") return } // Assuming WebSocketConnectedEvent.Type is a constant or can be derived // For now, let's use a hardcoded subject. This needs to be consistent with C# code. // C# uses `WebSocketConnectedEvent.Type` // TODO: Define NATS subjects in a centralized place or derive from proto events. natsSubject := "turbine.websocket.connected" err = m.natsClient.Publish(natsSubject, eventBytes) if err != nil { log.Error().Err(err).Msg("Failed to publish WebSocketConnectedEvent to NATS") } else { log.Info().Msgf("Published WebSocketConnectedEvent for %s:%s to NATS subject %s", accountID, deviceID, natsSubject) } } // PublishWebSocketDisconnectedEvent publishes a WebSocketDisconnectedEvent to NATS. func (m *Manager) PublishWebSocketDisconnectedEvent(accountID uuid.UUID, deviceID string, isOffline bool) { if m.natsClient == nil { log.Warn().Msg("NATS client not initialized. Cannot publish WebSocketDisconnectedEvent.") return } event := &pb.WebSocketDisconnectedEvent{ AccountId: wrapperspb.String(accountID.String()), DeviceId: deviceID, IsOffline: isOffline, } eventBytes, err := proto.Marshal(event) if err != nil { log.Error().Err(err).Msg("Failed to marshal WebSocketDisconnectedEvent") return } // Assuming WebSocketDisconnectedEvent.Type is a constant or can be derived natsSubject := "turbine.websocket.disconnected" err = m.natsClient.Publish(natsSubject, eventBytes) if err != nil { log.Error().Err(err).Msg("Failed to publish WebSocketDisconnectedEvent to NATS") } else { log.Info().Msgf("Published WebSocketDisconnectedEvent for %s:%s to NATS subject %s", accountID, deviceID, natsSubject) } } // HandlePacket processes incoming WebSocketPacket. func (m *Manager) HandlePacket(currentUser *pb.Account, deviceID string, packet *WebSocketPacket, conn *websocket.Conn) error { switch packet.Type { case WebSocketPacketTypePing: pongPacket := &WebSocketPacket{Type: WebSocketPacketTypePong} pongBytes, err := pongPacket.ToBytes() if err != nil { log.Error().Err(err).Msg("Failed to marshal pong packet") return err } return conn.WriteMessage(websocket.BinaryMessage, pongBytes) case WebSocketPacketTypeError: log.Error().Msgf("Received error packet from device %s: %s", deviceID, packet.ErrorMessage) return nil // Or handle error appropriately default: if packet.Endpoint != "" { return m.forwardPacketToNATS(currentUser, deviceID, packet) } else { errorPacket := &WebSocketPacket{ Type: WebSocketPacketTypeError, ErrorMessage: fmt.Sprintf("Unprocessable packet: %s", packet.Type), } errorBytes, err := errorPacket.ToBytes() if err != nil { log.Error().Err(err).Msg("Failed to marshal error packet") return err } return conn.WriteMessage(websocket.BinaryMessage, errorBytes) } } } func (m *Manager) forwardPacketToNATS(currentUser *pb.Account, deviceID string, packet *WebSocketPacket) error { if m.natsClient == nil { log.Warn().Msg("NATS client not initialized. Cannot forward packets to NATS.") return fmt.Errorf("NATS client not initialized") } packetBytes, err := packet.ToBytes() if err != nil { log.Error().Err(err).Msg("Failed to marshal WebSocketPacket for NATS forwarding") return err } // Convert currentUser.Id to string for the proto message var accountIDStr string if currentUser != nil && currentUser.Id != "" { accountIDStr = currentUser.Id } else { log.Warn().Msg("CurrentUser or CurrentUser.Id is nil/empty for NATS forwarding") return fmt.Errorf("current user ID is missing") } event := &pb.WebSocketPacketEvent{ AccountId: wrapperspb.String(accountIDStr), DeviceId: deviceID, PacketBytes: packetBytes, } eventBytes, err := proto.Marshal(event) if err != nil { log.Error().Err(err).Msg("Failed to marshal WebSocketPacketEvent for NATS") return err } // C# uses WebSocketPacketEvent.SubjectPrefix + endpoint // TODO: Centralize NATS subject definitions natsSubject := fmt.Sprintf("turbine.websocket.packet.%s", packet.Endpoint) err = m.natsClient.Publish(natsSubject, eventBytes) if err != nil { log.Error().Err(err).Msg("Failed to publish WebSocketPacketEvent to NATS") return err } log.Info().Msgf("Forwarded packet to NATS subject %s from device %s", natsSubject, deviceID) return nil }