Provide client id in stream push request

This commit is contained in:
2024-08-23 19:08:07 +08:00
parent 4f93800f72
commit df5676cbe4
5 changed files with 185 additions and 87 deletions

View File

@ -17,7 +17,17 @@ func (v *Server) CountStreamConnection(ctx context.Context, request *proto.Count
}
func (v *Server) PushStream(ctx context.Context, request *proto.PushStreamRequest) (*proto.PushStreamResponse, error) {
cnt, success, errs := services.WebsocketPush(uint(request.GetUserId()), request.GetBody())
var cnt int
var success int
var errs []error
if request.UserId != nil {
cnt, success, errs = services.WebsocketPush(uint(request.GetUserId()), request.GetBody())
} else if request.ClientId != nil {
cnt, success, errs = services.WebsocketPushDirect(request.GetClientId(), request.GetBody())
} else {
return nil, fmt.Errorf("you must give one of the user id or client id")
}
if len(errs) > 0 {
// Partial fail
return &proto.PushStreamResponse{
@ -38,12 +48,24 @@ func (v *Server) PushStream(ctx context.Context, request *proto.PushStreamReques
}
func (v *Server) PushStreamBatch(ctx context.Context, request *proto.PushStreamBatchRequest) (*proto.PushStreamResponse, error) {
cnt, success, errs := services.WebsocketPushBatch(
lo.Map(request.GetUserId(), func(item uint64, idx int) uint {
return uint(item)
},
), request.GetBody(),
)
var cnt int
var success int
var errs []error
if len(request.UserId) != 0 {
cnt, success, errs = services.WebsocketPushBatch(
lo.Map(request.GetUserId(), func(item uint64, idx int) uint {
return uint(item)
},
), request.GetBody(),
)
}
if len(request.ClientId) != 0 {
cCnt, cSuccess, cErrs := services.WebsocketPushBatchDirect(request.GetClientId(), request.GetBody())
cnt += cCnt
success += cSuccess
errs = append(errs, cErrs...)
}
if len(errs) > 0 {
// Partial fail
return &proto.PushStreamResponse{

View File

@ -3,6 +3,7 @@ package api
import (
"context"
"fmt"
"git.solsynth.dev/hydrogen/dealer/pkg/hyper"
"git.solsynth.dev/hydrogen/dealer/pkg/internal/directory"
"git.solsynth.dev/hydrogen/dealer/pkg/internal/models"
@ -18,8 +19,11 @@ func listenWebsocket(c *websocket.Conn) {
user := c.Locals("user").(models.Account)
// Push connection
services.ClientRegister(user, c)
log.Debug().Uint("user", user.ID).Msg("New websocket connection established...")
clientId := services.ClientRegister(user, c)
log.Debug().
Uint("user", user.ID).
Uint64("clientId", clientId).
Msg("New websocket connection established...")
// Event loop
var mt int
@ -63,11 +67,11 @@ func listenWebsocket(c *websocket.Conn) {
sc := proto.NewStreamControllerClient(pc)
_, err = sc.EmitStreamEvent(context.Background(), &proto.StreamEventRequest{
Event: packet.Action,
UserId: uint64(user.ID),
Payload: packet.RawPayload(),
Event: packet.Action,
UserId: uint64(user.ID),
ClientId: uint64(clientId),
Payload: packet.RawPayload(),
})
if err != nil {
_ = c.WriteMessage(mt, hyper.NetworkPackage{
Action: "error",
@ -78,6 +82,9 @@ func listenWebsocket(c *websocket.Conn) {
}
// Pop connection
services.ClientUnregister(user, c)
log.Debug().Uint("user", user.ID).Msg("A websocket connection disconnected...")
services.ClientUnregister(user, clientId)
log.Debug().
Uint("user", user.ID).
Uint64("clientId", clientId).
Msg("A websocket connection disconnected...")
}

View File

@ -2,6 +2,7 @@ package services
import (
"context"
"math/rand"
"sync"
"git.solsynth.dev/hydrogen/dealer/pkg/hyper"
@ -13,15 +14,16 @@ import (
var (
wsMutex sync.Mutex
wsConn = make(map[uint]map[*websocket.Conn]bool)
wsConn = make(map[uint]map[uint64]*websocket.Conn)
)
func ClientRegister(user models.Account, conn *websocket.Conn) {
func ClientRegister(user models.Account, conn *websocket.Conn) uint64 {
wsMutex.Lock()
if wsConn[user.ID] == nil {
wsConn[user.ID] = make(map[*websocket.Conn]bool)
wsConn[user.ID] = make(map[uint64]*websocket.Conn)
}
wsConn[user.ID][conn] = true
clientId := rand.Uint64()
wsConn[user.ID][clientId] = conn
wsMutex.Unlock()
pc, err := directory.GetServiceInstanceByType(hyper.ServiceTypeAuthProvider).GetGrpcConn()
@ -31,14 +33,16 @@ func ClientRegister(user models.Account, conn *websocket.Conn) {
UserId: uint64(user.ID),
})
}
return clientId
}
func ClientUnregister(user models.Account, conn *websocket.Conn) {
func ClientUnregister(user models.Account, id uint64) {
wsMutex.Lock()
if wsConn[user.ID] == nil {
wsConn[user.ID] = make(map[*websocket.Conn]bool)
wsConn[user.ID] = make(map[uint64]*websocket.Conn)
}
delete(wsConn[user.ID], conn)
delete(wsConn[user.ID], id)
wsMutex.Unlock()
pc, err := directory.GetServiceInstanceByType(hyper.ServiceTypeAuthProvider).GetGrpcConn()
@ -55,7 +59,7 @@ func ClientCount(uid uint) int {
}
func WebsocketPush(uid uint, body []byte) (count int, success int, errs []error) {
for conn := range wsConn[uid] {
for _, conn := range wsConn[uid] {
if err := conn.WriteMessage(1, body); err != nil {
errs = append(errs, err)
} else {
@ -66,9 +70,9 @@ func WebsocketPush(uid uint, body []byte) (count int, success int, errs []error)
return
}
func WebsocketPushBatch(uidList []uint, body []byte) (count int, success int, errs []error) {
for _, uid := range uidList {
for conn := range wsConn[uid] {
func WebsocketPushDirect(clientId uint64, body []byte) (count int, success int, errs []error) {
for _, m := range wsConn {
if conn, ok := m[clientId]; ok {
if err := conn.WriteMessage(1, body); err != nil {
errs = append(errs, err)
} else {
@ -79,3 +83,33 @@ func WebsocketPushBatch(uidList []uint, body []byte) (count int, success int, er
}
return
}
func WebsocketPushBatch(uidList []uint, body []byte) (count int, success int, errs []error) {
for _, uid := range uidList {
for _, conn := range wsConn[uid] {
if err := conn.WriteMessage(1, body); err != nil {
errs = append(errs, err)
} else {
success++
}
count++
}
}
return
}
func WebsocketPushBatchDirect(clientIdList []uint64, body []byte) (count int, success int, errs []error) {
for _, clientId := range clientIdList {
for _, m := range wsConn {
if conn, ok := m[clientId]; ok {
if err := conn.WriteMessage(1, body); err != nil {
errs = append(errs, err)
} else {
success++
}
count++
}
}
}
return
}