diff --git a/pkg/internal/grpc/server.go b/pkg/internal/grpc/server.go index 7fadaaf..fccc058 100644 --- a/pkg/internal/grpc/server.go +++ b/pkg/internal/grpc/server.go @@ -14,6 +14,7 @@ import ( type Server struct { proto.UnimplementedServiceDirectoryServer + proto.UnimplementedStreamControllerServer proto.UnimplementedAuthServer srv *grpc.Server @@ -24,9 +25,10 @@ func NewServer() *Server { srv: grpc.NewServer(), } - proto.RegisterServiceDirectoryServer(server.srv, &Server{}) - proto.RegisterAuthServer(server.srv, &Server{}) - health.RegisterHealthServer(server.srv, &Server{}) + proto.RegisterServiceDirectoryServer(server.srv, server) + proto.RegisterStreamControllerServer(server.srv, server) + proto.RegisterAuthServer(server.srv, server) + health.RegisterHealthServer(server.srv, server) reflection.Register(server.srv) diff --git a/pkg/internal/grpc/stream.go b/pkg/internal/grpc/stream.go new file mode 100644 index 0000000..bece984 --- /dev/null +++ b/pkg/internal/grpc/stream.go @@ -0,0 +1,29 @@ +package grpc + +import ( + "context" + "fmt" + "git.solsynth.dev/hydrogen/dealer/pkg/internal/services" + "git.solsynth.dev/hydrogen/dealer/pkg/proto" +) + +func (v *Server) PushStream(ctx context.Context, request *proto.PushStreamRequest) (*proto.PushStreamResponse, error) { + cnt, success, errs := services.WebsocketPush(uint(request.GetUserId()), request.GetBody()) + if len(errs) > 0 { + // Partial fail + return &proto.PushStreamResponse{ + IsAllSuccess: false, + AffectedCount: int64(success), + FailedCount: int64(cnt - success), + }, nil + } else if cnt > 0 && success == 0 { + // All fail + return nil, fmt.Errorf("all push request failed: %v", errs) + } + + return &proto.PushStreamResponse{ + IsAllSuccess: true, + AffectedCount: int64(success), + FailedCount: int64(cnt - success), + }, nil +} diff --git a/pkg/internal/services/connections.go b/pkg/internal/services/connections.go index 779def3..b9c2d87 100644 --- a/pkg/internal/services/connections.go +++ b/pkg/internal/services/connections.go @@ -28,3 +28,15 @@ func ClientUnregister(user models.Account, conn *websocket.Conn) { delete(wsConn[user.ID], conn) wsMutex.Unlock() } + +func WebsocketPush(uid uint, body []byte) (count int, success int, errs []error) { + for conn := range wsConn[uid] { + if err := conn.WriteMessage(1, body); err != nil { + errs = append(errs, err) + } else { + success++ + } + count++ + } + return +} diff --git a/pkg/proto/auth.pb.go b/pkg/proto/auth.pb.go index 7840c10..deeef37 100644 --- a/pkg/proto/auth.pb.go +++ b/pkg/proto/auth.pb.go @@ -504,9 +504,9 @@ var file_auth_proto_depIdxs = []int32{ 0, // 0: proto.AuthInfo.info:type_name -> proto.UserInfo 1, // 1: proto.AuthReply.info:type_name -> proto.AuthInfo 2, // 2: proto.Auth.Authenticate:input_type -> proto.AuthRequest - 4, // 3: proto.Auth.CheckPermGranted:input_type -> proto.CheckPermRequest + 4, // 3: proto.Auth.EnsurePermGranted:input_type -> proto.CheckPermRequest 3, // 4: proto.Auth.Authenticate:output_type -> proto.AuthReply - 5, // 5: proto.Auth.CheckPermGranted:output_type -> proto.CheckPermReply + 5, // 5: proto.Auth.EnsurePermGranted:output_type -> proto.CheckPermReply 4, // [4:6] is the sub-list for method output_type 2, // [2:4] is the sub-list for method input_type 2, // [2:2] is the sub-list for extension type_name diff --git a/pkg/proto/auth_grpc.pb.go b/pkg/proto/auth_grpc.pb.go index e87f652..b9812e5 100644 --- a/pkg/proto/auth_grpc.pb.go +++ b/pkg/proto/auth_grpc.pb.go @@ -20,7 +20,7 @@ const _ = grpc.SupportPackageIsVersion8 const ( Auth_Authenticate_FullMethodName = "/proto.Auth/Authenticate" - Auth_EnsurePermGranted_FullMethodName = "/proto.Auth/CheckPermGranted" + Auth_EnsurePermGranted_FullMethodName = "/proto.Auth/EnsurePermGranted" ) // AuthClient is the client API for Auth service. @@ -76,7 +76,7 @@ func (UnimplementedAuthServer) Authenticate(context.Context, *AuthRequest) (*Aut return nil, status.Errorf(codes.Unimplemented, "method Authenticate not implemented") } func (UnimplementedAuthServer) EnsurePermGranted(context.Context, *CheckPermRequest) (*CheckPermReply, error) { - return nil, status.Errorf(codes.Unimplemented, "method CheckPermGranted not implemented") + return nil, status.Errorf(codes.Unimplemented, "method EnsurePermGranted not implemented") } func (UnimplementedAuthServer) mustEmbedUnimplementedAuthServer() {} @@ -139,7 +139,7 @@ var Auth_ServiceDesc = grpc.ServiceDesc{ Handler: _Auth_Authenticate_Handler, }, { - MethodName: "CheckPermGranted", + MethodName: "EnsurePermGranted", Handler: _Auth_EnsurePermGranted_Handler, }, }, diff --git a/pkg/proto/stream.pb.go b/pkg/proto/stream.pb.go new file mode 100644 index 0000000..da10e1f --- /dev/null +++ b/pkg/proto/stream.pb.go @@ -0,0 +1,243 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.34.2 +// protoc v5.27.1 +// source: stream.proto + +package proto + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type PushStreamRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + UserId uint64 `protobuf:"varint,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + Body []byte `protobuf:"bytes,2,opt,name=body,proto3" json:"body,omitempty"` +} + +func (x *PushStreamRequest) Reset() { + *x = PushStreamRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_stream_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PushStreamRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PushStreamRequest) ProtoMessage() {} + +func (x *PushStreamRequest) ProtoReflect() protoreflect.Message { + mi := &file_stream_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PushStreamRequest.ProtoReflect.Descriptor instead. +func (*PushStreamRequest) Descriptor() ([]byte, []int) { + return file_stream_proto_rawDescGZIP(), []int{0} +} + +func (x *PushStreamRequest) GetUserId() uint64 { + if x != nil { + return x.UserId + } + return 0 +} + +func (x *PushStreamRequest) GetBody() []byte { + if x != nil { + return x.Body + } + return nil +} + +type PushStreamResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + IsAllSuccess bool `protobuf:"varint,1,opt,name=is_all_success,json=isAllSuccess,proto3" json:"is_all_success,omitempty"` + AffectedCount int64 `protobuf:"varint,2,opt,name=affected_count,json=affectedCount,proto3" json:"affected_count,omitempty"` + FailedCount int64 `protobuf:"varint,3,opt,name=failed_count,json=failedCount,proto3" json:"failed_count,omitempty"` +} + +func (x *PushStreamResponse) Reset() { + *x = PushStreamResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_stream_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PushStreamResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PushStreamResponse) ProtoMessage() {} + +func (x *PushStreamResponse) ProtoReflect() protoreflect.Message { + mi := &file_stream_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PushStreamResponse.ProtoReflect.Descriptor instead. +func (*PushStreamResponse) Descriptor() ([]byte, []int) { + return file_stream_proto_rawDescGZIP(), []int{1} +} + +func (x *PushStreamResponse) GetIsAllSuccess() bool { + if x != nil { + return x.IsAllSuccess + } + return false +} + +func (x *PushStreamResponse) GetAffectedCount() int64 { + if x != nil { + return x.AffectedCount + } + return 0 +} + +func (x *PushStreamResponse) GetFailedCount() int64 { + if x != nil { + return x.FailedCount + } + return 0 +} + +var File_stream_proto protoreflect.FileDescriptor + +var file_stream_proto_rawDesc = []byte{ + 0x0a, 0x0c, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x40, 0x0a, 0x11, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x72, + 0x65, 0x61, 0x6d, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, + 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x75, 0x73, 0x65, + 0x72, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0c, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0x84, 0x01, 0x0a, 0x12, 0x50, 0x75, 0x73, 0x68, + 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, + 0x0a, 0x0e, 0x69, 0x73, 0x5f, 0x61, 0x6c, 0x6c, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x69, 0x73, 0x41, 0x6c, 0x6c, 0x53, 0x75, 0x63, + 0x63, 0x65, 0x73, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x61, 0x66, 0x66, 0x65, 0x63, 0x74, 0x65, 0x64, + 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x61, 0x66, + 0x66, 0x65, 0x63, 0x74, 0x65, 0x64, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x21, 0x0a, 0x0c, 0x66, + 0x61, 0x69, 0x6c, 0x65, 0x64, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x0b, 0x66, 0x61, 0x69, 0x6c, 0x65, 0x64, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x32, 0x57, + 0x0a, 0x10, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x43, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, + 0x65, 0x72, 0x12, 0x43, 0x0a, 0x0a, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, + 0x12, 0x18, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x72, + 0x65, 0x61, 0x6d, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x3b, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_stream_proto_rawDescOnce sync.Once + file_stream_proto_rawDescData = file_stream_proto_rawDesc +) + +func file_stream_proto_rawDescGZIP() []byte { + file_stream_proto_rawDescOnce.Do(func() { + file_stream_proto_rawDescData = protoimpl.X.CompressGZIP(file_stream_proto_rawDescData) + }) + return file_stream_proto_rawDescData +} + +var file_stream_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_stream_proto_goTypes = []any{ + (*PushStreamRequest)(nil), // 0: proto.PushStreamRequest + (*PushStreamResponse)(nil), // 1: proto.PushStreamResponse +} +var file_stream_proto_depIdxs = []int32{ + 0, // 0: proto.StreamController.PushStream:input_type -> proto.PushStreamRequest + 1, // 1: proto.StreamController.PushStream:output_type -> proto.PushStreamResponse + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_stream_proto_init() } +func file_stream_proto_init() { + if File_stream_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_stream_proto_msgTypes[0].Exporter = func(v any, i int) any { + switch v := v.(*PushStreamRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_stream_proto_msgTypes[1].Exporter = func(v any, i int) any { + switch v := v.(*PushStreamResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_stream_proto_rawDesc, + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_stream_proto_goTypes, + DependencyIndexes: file_stream_proto_depIdxs, + MessageInfos: file_stream_proto_msgTypes, + }.Build() + File_stream_proto = out.File + file_stream_proto_rawDesc = nil + file_stream_proto_goTypes = nil + file_stream_proto_depIdxs = nil +} diff --git a/pkg/proto/stream.proto b/pkg/proto/stream.proto new file mode 100644 index 0000000..7e05d45 --- /dev/null +++ b/pkg/proto/stream.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +option go_package = ".;proto"; + +package proto; + +service StreamController { + rpc PushStream(PushStreamRequest) returns (PushStreamResponse) {} +} + +message PushStreamRequest { + uint64 user_id = 1; + bytes body = 2; +} + +message PushStreamResponse { + bool is_all_success = 1; + int64 affected_count = 2; + int64 failed_count = 3; +} \ No newline at end of file diff --git a/pkg/proto/stream_grpc.pb.go b/pkg/proto/stream_grpc.pb.go new file mode 100644 index 0000000..b0100f9 --- /dev/null +++ b/pkg/proto/stream_grpc.pb.go @@ -0,0 +1,110 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.4.0 +// - protoc v5.27.1 +// source: stream.proto + +package proto + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.62.0 or later. +const _ = grpc.SupportPackageIsVersion8 + +const ( + StreamController_PushStream_FullMethodName = "/proto.StreamController/PushStream" +) + +// StreamControllerClient is the client API for StreamController service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type StreamControllerClient interface { + PushStream(ctx context.Context, in *PushStreamRequest, opts ...grpc.CallOption) (*PushStreamResponse, error) +} + +type streamControllerClient struct { + cc grpc.ClientConnInterface +} + +func NewStreamControllerClient(cc grpc.ClientConnInterface) StreamControllerClient { + return &streamControllerClient{cc} +} + +func (c *streamControllerClient) PushStream(ctx context.Context, in *PushStreamRequest, opts ...grpc.CallOption) (*PushStreamResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(PushStreamResponse) + err := c.cc.Invoke(ctx, StreamController_PushStream_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// StreamControllerServer is the server API for StreamController service. +// All implementations must embed UnimplementedStreamControllerServer +// for forward compatibility +type StreamControllerServer interface { + PushStream(context.Context, *PushStreamRequest) (*PushStreamResponse, error) + mustEmbedUnimplementedStreamControllerServer() +} + +// UnimplementedStreamControllerServer must be embedded to have forward compatible implementations. +type UnimplementedStreamControllerServer struct { +} + +func (UnimplementedStreamControllerServer) PushStream(context.Context, *PushStreamRequest) (*PushStreamResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method PushStream not implemented") +} +func (UnimplementedStreamControllerServer) mustEmbedUnimplementedStreamControllerServer() {} + +// UnsafeStreamControllerServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to StreamControllerServer will +// result in compilation errors. +type UnsafeStreamControllerServer interface { + mustEmbedUnimplementedStreamControllerServer() +} + +func RegisterStreamControllerServer(s grpc.ServiceRegistrar, srv StreamControllerServer) { + s.RegisterService(&StreamController_ServiceDesc, srv) +} + +func _StreamController_PushStream_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PushStreamRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(StreamControllerServer).PushStream(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: StreamController_PushStream_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(StreamControllerServer).PushStream(ctx, req.(*PushStreamRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// StreamController_ServiceDesc is the grpc.ServiceDesc for StreamController service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var StreamController_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "proto.StreamController", + HandlerType: (*StreamControllerServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "PushStream", + Handler: _StreamController_PushStream_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "stream.proto", +}