🎉 Initial Commit of Ring
This commit is contained in:
144
pkg/ring/websocket/controller.go
Normal file
144
pkg/ring/websocket/controller.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/contrib/v3/websocket"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
|
||||
pb "git.solsynth.dev/goatworks/turbine/pkg/shared/proto/gen"
|
||||
)
|
||||
|
||||
// WebSocketController handles the WebSocket endpoint.
|
||||
type WebSocketController struct {
|
||||
Manager *Manager
|
||||
}
|
||||
|
||||
// NewWebSocketController creates a new WebSocketController.
|
||||
func NewWebSocketController(manager *Manager) *WebSocketController {
|
||||
return &WebSocketController{
|
||||
Manager: manager,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleWebSocket is the main handler for the /ws endpoint.
|
||||
func (wc *WebSocketController) HandleWebSocket(c *websocket.Conn) {
|
||||
// Mock Authentication for now based on C# example
|
||||
// In a real scenario, this would involve JWT verification, session lookup, etc.
|
||||
// For demonstration, we'll assume a dummy user and session.
|
||||
// The C# code uses HttpContext.Items to get CurrentUser and CurrentSession.
|
||||
// In Go Fiber, we can pass this through Locals or use middleware to set it up.
|
||||
// For now, let's create a dummy user and session.
|
||||
|
||||
// TODO: Replace with actual authentication logic
|
||||
// For now, assume a dummy account and session
|
||||
// Based on the C# code, CurrentUser and CurrentSession are pb.Account and pb.AuthSession
|
||||
dummyAccount := &pb.Account{
|
||||
Id: uuid.New().String(),
|
||||
Name: "dummy_user",
|
||||
Nick: "Dummy User",
|
||||
}
|
||||
dummySession := &pb.AuthSession{
|
||||
ClientId: lo.ToPtr(uuid.New().String()), // This is used as deviceId if not present
|
||||
}
|
||||
|
||||
// Device ID handling
|
||||
deviceAlt := c.Query("deviceAlt")
|
||||
if deviceAlt != "" {
|
||||
allowedDeviceAlternative := []string{"watch"} // Hardcoded for now
|
||||
found := false
|
||||
for _, alt := range allowedDeviceAlternative {
|
||||
if deviceAlt == alt {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
log.Warn().Msgf("Unsupported device alternative: %s", deviceAlt)
|
||||
bytes, _ := newErrorPacket("Unsupported device alternative: " + deviceAlt).ToBytes()
|
||||
c.WriteMessage(websocket.BinaryMessage, bytes)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
accountID := uuid.MustParse(dummyAccount.Id)
|
||||
deviceIDStr := ""
|
||||
if dummySession.ClientId == nil {
|
||||
deviceIDStr = uuid.New().String()
|
||||
} else {
|
||||
deviceIDStr = *dummySession.ClientId
|
||||
}
|
||||
if deviceAlt != "" {
|
||||
deviceIDStr = fmt.Sprintf("%s+%s", deviceIDStr, deviceAlt)
|
||||
}
|
||||
|
||||
// Setup connection context for cancellation
|
||||
cancel := func() {} // Placeholder
|
||||
connectionKey := ConnectionKey{AccountID: accountID, DeviceID: deviceIDStr}
|
||||
|
||||
// Add connection to manager
|
||||
if !wc.Manager.TryAdd(connectionKey, c, cancel) {
|
||||
bytes, _ := newErrorPacket("Too many connections from the same device and account.").ToBytes()
|
||||
c.WriteMessage(websocket.BinaryMessage, bytes)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Connection established with user @%s#%s and device #%s", dummyAccount.Name, dummyAccount.Id, deviceIDStr)
|
||||
|
||||
// Publish WebSocket connected event
|
||||
wc.Manager.PublishWebSocketConnectedEvent(accountID, deviceIDStr, false) // isOffline is false on connection
|
||||
|
||||
defer func() {
|
||||
wc.Manager.Disconnect(connectionKey, "Client disconnected.")
|
||||
// Publish WebSocket disconnected event
|
||||
isOffline := !wc.Manager.GetAccountIsConnected(accountID) // Check if account is completely offline
|
||||
wc.Manager.PublishWebSocketDisconnectedEvent(accountID, deviceIDStr, isOffline)
|
||||
log.Debug().Msgf("Connection disconnected with user @%s#%s and device #%s", dummyAccount.Name, dummyAccount.Id, deviceIDStr)
|
||||
}()
|
||||
|
||||
// Main event loop
|
||||
for {
|
||||
mt, msg, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("WebSocket read error")
|
||||
break
|
||||
}
|
||||
|
||||
if mt == websocket.CloseMessage {
|
||||
log.Info().Msg("Received close message from client")
|
||||
break
|
||||
}
|
||||
|
||||
if mt != websocket.BinaryMessage {
|
||||
log.Warn().Msgf("Received non-binary message type: %d", mt)
|
||||
continue
|
||||
}
|
||||
|
||||
packet, err := FromBytes(msg)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to deserialize WebSocket packet")
|
||||
bytes, _ := newErrorPacket("Failed to deserialize packet").ToBytes()
|
||||
c.WriteMessage(websocket.BinaryMessage, bytes)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := wc.Manager.HandlePacket(dummyAccount, deviceIDStr, packet, c); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to handle incoming WebSocket packet")
|
||||
bytes, _ := newErrorPacket("Failed to process packet").ToBytes()
|
||||
c.WriteMessage(websocket.BinaryMessage, bytes)
|
||||
// Depending on error, might want to close connection
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newErrorPacket creates a new WebSocketPacket with an error type.
|
||||
func newErrorPacket(message string) *WebSocketPacket {
|
||||
return &WebSocketPacket{
|
||||
Type: WebSocketPacketTypeError,
|
||||
ErrorMessage: message,
|
||||
}
|
||||
}
|
||||
292
pkg/ring/websocket/manager.go
Normal file
292
pkg/ring/websocket/manager.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/gofiber/contrib/v3/websocket"
|
||||
"github.com/google/uuid"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/protobuf/proto" // Import for proto.Marshal
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
|
||||
pb "git.solsynth.dev/goatworks/turbine/pkg/shared/proto/gen"
|
||||
)
|
||||
|
||||
// ConnectionKey represents the unique identifier for a WebSocket connection.
|
||||
type ConnectionKey struct {
|
||||
AccountID uuid.UUID
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
// ConnectionState holds the WebSocket connection and its cancellation context.
|
||||
type ConnectionState struct {
|
||||
Conn *websocket.Conn
|
||||
Cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// Manager manages active WebSocket connections.
|
||||
type Manager struct {
|
||||
connections sync.Map // Map[ConnectionKey]*ConnectionState
|
||||
natsClient *nats.Conn
|
||||
}
|
||||
|
||||
// NewManager creates a new WebSocket Manager.
|
||||
func NewManager(natsClient *nats.Conn) *Manager {
|
||||
return &Manager{
|
||||
natsClient: natsClient,
|
||||
}
|
||||
}
|
||||
|
||||
// TryAdd attempts to add a new connection. If a connection with the same key exists,
|
||||
// it disconnects the old one and then adds the new one.
|
||||
func (m *Manager) TryAdd(key ConnectionKey, conn *websocket.Conn, cancel context.CancelFunc) bool {
|
||||
// Disconnect existing connection with the same identifier, if any
|
||||
if _, loaded := m.connections.Load(key); loaded {
|
||||
log.Warn().Msgf("Duplicate connection detected for %s:%s. Disconnecting old one.", key.AccountID, key.DeviceID)
|
||||
m.Disconnect(key, "Just connected somewhere else with the same identifier.")
|
||||
}
|
||||
|
||||
m.connections.Store(key, &ConnectionState{Conn: conn, Cancel: cancel})
|
||||
log.Info().Msgf("Connection established for user %s and device %s", key.AccountID, key.DeviceID)
|
||||
return true
|
||||
}
|
||||
|
||||
// Disconnect removes a connection and closes it.
|
||||
func (m *Manager) Disconnect(key ConnectionKey, reason string) {
|
||||
if stateAny, loaded := m.connections.LoadAndDelete(key); loaded {
|
||||
state := stateAny.(ConnectionState)
|
||||
|
||||
// Cancel the context to stop any goroutines associated with this connection
|
||||
if state.Cancel != nil {
|
||||
state.Cancel()
|
||||
}
|
||||
|
||||
// Close the WebSocket connection
|
||||
err := state.Conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, reason))
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Error sending close message to WebSocket for %s:%s", key.AccountID, key.DeviceID)
|
||||
}
|
||||
err = state.Conn.Close() // Ensure the connection is closed
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Error closing WebSocket for %s:%s", key.AccountID, key.DeviceID)
|
||||
}
|
||||
|
||||
log.Info().Msgf("Connection disconnected for user %s and device %s. Reason: %s", key.AccountID, key.DeviceID, reason)
|
||||
}
|
||||
}
|
||||
|
||||
// GetDeviceIsConnected checks if any connection exists for a given device ID.
|
||||
func (m *Manager) GetDeviceIsConnected(deviceID string) bool {
|
||||
var isConnected bool
|
||||
m.connections.Range(func(k, v interface{}) bool {
|
||||
connKey := k.(ConnectionKey)
|
||||
if connKey.DeviceID == deviceID {
|
||||
isConnected = true
|
||||
return false // Stop iteration
|
||||
}
|
||||
return true
|
||||
})
|
||||
return isConnected
|
||||
}
|
||||
|
||||
// GetAccountIsConnected checks if any connection exists for a given account ID.
|
||||
func (m *Manager) GetAccountIsConnected(accountID uuid.UUID) bool {
|
||||
var isConnected bool
|
||||
m.connections.Range(func(k, v interface{}) bool {
|
||||
connKey := k.(ConnectionKey)
|
||||
if connKey.AccountID == accountID {
|
||||
isConnected = true
|
||||
return false // Stop iteration
|
||||
}
|
||||
return true
|
||||
})
|
||||
return isConnected
|
||||
}
|
||||
|
||||
// SendPacketToAccount sends a WebSocketPacket to all connections for a given account ID.
|
||||
func (m *Manager) SendPacketToAccount(accountID uuid.UUID, packet *WebSocketPacket) {
|
||||
packetBytes, err := packet.ToBytes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to marshal packet for account %s", accountID)
|
||||
return
|
||||
}
|
||||
|
||||
m.connections.Range(func(k, v interface{}) bool {
|
||||
connKey := k.(ConnectionKey)
|
||||
if connKey.AccountID == accountID {
|
||||
state := v.(*ConnectionState)
|
||||
if state.Conn != nil {
|
||||
err := state.Conn.WriteMessage(websocket.BinaryMessage, packetBytes)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to send packet to account %s, device %s", accountID, connKey.DeviceID)
|
||||
// Optionally, disconnect this problematic connection
|
||||
// m.Disconnect(connKey, "Failed to send packet")
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// SendPacketToDevice sends a WebSocketPacket to all connections for a given device ID.
|
||||
func (m *Manager) SendPacketToDevice(deviceID string, packet *WebSocketPacket) {
|
||||
packetBytes, err := packet.ToBytes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to marshal packet for device %s", deviceID)
|
||||
return
|
||||
}
|
||||
|
||||
m.connections.Range(func(k, v interface{}) bool {
|
||||
connKey := k.(ConnectionKey)
|
||||
if connKey.DeviceID == deviceID {
|
||||
state := v.(*ConnectionState)
|
||||
if state.Conn != nil {
|
||||
err := state.Conn.WriteMessage(websocket.BinaryMessage, packetBytes)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Failed to send packet to device %s, account %s", deviceID, connKey.AccountID)
|
||||
// Optionally, disconnect this problematic connection
|
||||
// m.Disconnect(connKey, "Failed to send packet")
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// PublishWebSocketConnectedEvent publishes a WebSocketConnectedEvent to NATS.
|
||||
func (m *Manager) PublishWebSocketConnectedEvent(accountID uuid.UUID, deviceID string, isOffline bool) {
|
||||
if m.natsClient == nil {
|
||||
log.Warn().Msg("NATS client not initialized. Cannot publish WebSocketConnectedEvent.")
|
||||
return
|
||||
}
|
||||
|
||||
event := &pb.WebSocketConnectedEvent{
|
||||
AccountId: wrapperspb.String(accountID.String()),
|
||||
DeviceId: deviceID,
|
||||
IsOffline: isOffline,
|
||||
}
|
||||
eventBytes, err := proto.Marshal(event)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal WebSocketConnectedEvent")
|
||||
return
|
||||
}
|
||||
|
||||
// Assuming WebSocketConnectedEvent.Type is a constant or can be derived
|
||||
// For now, let's use a hardcoded subject. This needs to be consistent with C# code.
|
||||
// C# uses `WebSocketConnectedEvent.Type`
|
||||
// TODO: Define NATS subjects in a centralized place or derive from proto events.
|
||||
natsSubject := "turbine.websocket.connected"
|
||||
err = m.natsClient.Publish(natsSubject, eventBytes)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to publish WebSocketConnectedEvent to NATS")
|
||||
} else {
|
||||
log.Info().Msgf("Published WebSocketConnectedEvent for %s:%s to NATS subject %s", accountID, deviceID, natsSubject)
|
||||
}
|
||||
}
|
||||
|
||||
// PublishWebSocketDisconnectedEvent publishes a WebSocketDisconnectedEvent to NATS.
|
||||
func (m *Manager) PublishWebSocketDisconnectedEvent(accountID uuid.UUID, deviceID string, isOffline bool) {
|
||||
if m.natsClient == nil {
|
||||
log.Warn().Msg("NATS client not initialized. Cannot publish WebSocketDisconnectedEvent.")
|
||||
return
|
||||
}
|
||||
|
||||
event := &pb.WebSocketDisconnectedEvent{
|
||||
AccountId: wrapperspb.String(accountID.String()),
|
||||
DeviceId: deviceID,
|
||||
IsOffline: isOffline,
|
||||
}
|
||||
eventBytes, err := proto.Marshal(event)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal WebSocketDisconnectedEvent")
|
||||
return
|
||||
}
|
||||
|
||||
// Assuming WebSocketDisconnectedEvent.Type is a constant or can be derived
|
||||
natsSubject := "turbine.websocket.disconnected"
|
||||
err = m.natsClient.Publish(natsSubject, eventBytes)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to publish WebSocketDisconnectedEvent to NATS")
|
||||
} else {
|
||||
log.Info().Msgf("Published WebSocketDisconnectedEvent for %s:%s to NATS subject %s", accountID, deviceID, natsSubject)
|
||||
}
|
||||
}
|
||||
|
||||
// HandlePacket processes incoming WebSocketPacket.
|
||||
func (m *Manager) HandlePacket(currentUser *pb.Account, deviceID string, packet *WebSocketPacket, conn *websocket.Conn) error {
|
||||
switch packet.Type {
|
||||
case WebSocketPacketTypePing:
|
||||
pongPacket := &WebSocketPacket{Type: WebSocketPacketTypePong}
|
||||
pongBytes, err := pongPacket.ToBytes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal pong packet")
|
||||
return err
|
||||
}
|
||||
return conn.WriteMessage(websocket.BinaryMessage, pongBytes)
|
||||
case WebSocketPacketTypeError:
|
||||
log.Error().Msgf("Received error packet from device %s: %s", deviceID, packet.ErrorMessage)
|
||||
return nil // Or handle error appropriately
|
||||
default:
|
||||
if packet.Endpoint != "" {
|
||||
return m.forwardPacketToNATS(currentUser, deviceID, packet)
|
||||
} else {
|
||||
errorPacket := &WebSocketPacket{
|
||||
Type: WebSocketPacketTypeError,
|
||||
ErrorMessage: fmt.Sprintf("Unprocessable packet: %s", packet.Type),
|
||||
}
|
||||
errorBytes, err := errorPacket.ToBytes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal error packet")
|
||||
return err
|
||||
}
|
||||
return conn.WriteMessage(websocket.BinaryMessage, errorBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) forwardPacketToNATS(currentUser *pb.Account, deviceID string, packet *WebSocketPacket) error {
|
||||
if m.natsClient == nil {
|
||||
log.Warn().Msg("NATS client not initialized. Cannot forward packets to NATS.")
|
||||
return fmt.Errorf("NATS client not initialized")
|
||||
}
|
||||
|
||||
packetBytes, err := packet.ToBytes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal WebSocketPacket for NATS forwarding")
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert currentUser.Id to string for the proto message
|
||||
var accountIDStr string
|
||||
if currentUser != nil && currentUser.Id != "" {
|
||||
accountIDStr = currentUser.Id
|
||||
} else {
|
||||
log.Warn().Msg("CurrentUser or CurrentUser.Id is nil/empty for NATS forwarding")
|
||||
return fmt.Errorf("current user ID is missing")
|
||||
}
|
||||
|
||||
event := &pb.WebSocketPacketEvent{
|
||||
AccountId: wrapperspb.String(accountIDStr),
|
||||
DeviceId: deviceID,
|
||||
PacketBytes: packetBytes,
|
||||
}
|
||||
eventBytes, err := proto.Marshal(event)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to marshal WebSocketPacketEvent for NATS")
|
||||
return err
|
||||
}
|
||||
|
||||
// C# uses WebSocketPacketEvent.SubjectPrefix + endpoint
|
||||
// TODO: Centralize NATS subject definitions
|
||||
natsSubject := fmt.Sprintf("turbine.websocket.packet.%s", packet.Endpoint)
|
||||
err = m.natsClient.Publish(natsSubject, eventBytes)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to publish WebSocketPacketEvent to NATS")
|
||||
return err
|
||||
}
|
||||
log.Info().Msgf("Forwarded packet to NATS subject %s from device %s", natsSubject, deviceID)
|
||||
return nil
|
||||
}
|
||||
82
pkg/ring/websocket/models.go
Normal file
82
pkg/ring/websocket/models.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
pb "git.solsynth.dev/goatworks/turbine/pkg/shared/proto/gen"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
)
|
||||
|
||||
// WebSocketPacket represents a WebSocket message packet.
|
||||
type WebSocketPacket struct {
|
||||
Type string `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"` // Use json.RawMessage to delay deserialization
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
}
|
||||
|
||||
// ToBytes serializes the WebSocketPacket to a byte array for sending over WebSocket.
|
||||
func (w *WebSocketPacket) ToBytes() ([]byte, error) {
|
||||
return json.Marshal(w)
|
||||
}
|
||||
|
||||
// FromBytes deserializes a byte array into a WebSocketPacket.
|
||||
func FromBytes(bytes []byte) (*WebSocketPacket, error) {
|
||||
var packet WebSocketPacket
|
||||
err := json.Unmarshal(bytes, &packet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to deserialize WebSocketPacket: %w", err)
|
||||
}
|
||||
return &packet, nil
|
||||
}
|
||||
|
||||
// GetData deserializes the Data property to the specified type T.
|
||||
func (w *WebSocketPacket) GetData(v interface{}) error {
|
||||
if w.Data == nil {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(w.Data, v)
|
||||
}
|
||||
|
||||
// ToProtoValue converts the WebSocketPacket to its protobuf equivalent.
|
||||
func (w *WebSocketPacket) ToProtoValue() *pb.WebSocketPacket {
|
||||
var dataBytes []byte
|
||||
if w.Data != nil {
|
||||
dataBytes = w.Data
|
||||
}
|
||||
var errorMessage string
|
||||
if w.ErrorMessage != "" {
|
||||
errorMessage = w.ErrorMessage
|
||||
}
|
||||
return &pb.WebSocketPacket{
|
||||
Type: w.Type,
|
||||
Data: dataBytes,
|
||||
ErrorMessage: wrapperspb.String(errorMessage),
|
||||
}
|
||||
}
|
||||
|
||||
// FromProtoValue converts a protobuf WebSocketPacket to its Go struct equivalent.
|
||||
func FromProtoValue(packet *pb.WebSocketPacket) *WebSocketPacket {
|
||||
var data json.RawMessage
|
||||
if packet.Data != nil {
|
||||
data = json.RawMessage(packet.Data)
|
||||
}
|
||||
var errorMessage string
|
||||
if packet.ErrorMessage != nil {
|
||||
errorMessage = packet.ErrorMessage.GetValue()
|
||||
}
|
||||
|
||||
return &WebSocketPacket{
|
||||
Type: packet.Type,
|
||||
Data: data,
|
||||
ErrorMessage: errorMessage,
|
||||
}
|
||||
}
|
||||
|
||||
// WebSocketPacketType constants from C# example
|
||||
const (
|
||||
WebSocketPacketTypePing = "ping"
|
||||
WebSocketPacketTypePong = "pong"
|
||||
WebSocketPacketTypeError = "error"
|
||||
)
|
||||
Reference in New Issue
Block a user