diff --git a/pkg/internal/grpc/payment.go b/pkg/internal/grpc/payment.go index 42040f4..f3c7811 100644 --- a/pkg/internal/grpc/payment.go +++ b/pkg/internal/grpc/payment.go @@ -2,11 +2,11 @@ package grpc import ( "context" + "git.solsynth.dev/hypernet/wallet/pkg/internal/database" "git.solsynth.dev/hypernet/wallet/pkg/internal/models" + "git.solsynth.dev/hypernet/wallet/pkg/internal/services" "git.solsynth.dev/hypernet/wallet/pkg/proto" - "github.com/samber/lo" - "github.com/shopspring/decimal" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -38,18 +38,20 @@ func (v *Server) MakeTransaction(ctx context.Context, request *proto.MakeTransac return nil, status.Errorf(codes.InvalidArgument, "payer and payee cannot be both nil") } - transaction := models.Transaction{ - Amount: decimal.NewFromFloat(request.Amount), - Remark: request.Remark, - } + var payerWallet, payeeWallet *models.Wallet if request.PayerId != nil { - transaction.PayerID = lo.ToPtr(uint(*request.PayerId)) + if err := database.C.Where("id = ?", request.PayerId).First(&payerWallet); err != nil { + return nil, status.Errorf(codes.NotFound, "payer wallet not found: %v", err) + } } if request.PayeeId != nil { - transaction.PayeeID = lo.ToPtr(uint(*request.PayeeId)) + if err := database.C.Where("id = ?", request.PayeeId).First(&payeeWallet); err != nil { + return nil, status.Errorf(codes.NotFound, "payee wallet not found: %v", err) + } } - if err := database.C.Create(&transaction).Error; err != nil { + transaction, err := services.MakeTransaction(request.GetAmount(), request.GetRemark(), payerWallet, payeeWallet) + if err != nil { return nil, status.Errorf(codes.Internal, err.Error()) } @@ -61,30 +63,22 @@ func (v *Server) MakeTransactionWithAccount(ctx context.Context, request *proto. return nil, status.Errorf(codes.InvalidArgument, "payer and payee cannot be both nil") } - transaction := models.Transaction{ - Amount: decimal.NewFromFloat(request.Amount), - Remark: request.Remark, - } + var payerWallet, payeeWallet *models.Wallet if request.PayerAccountId != nil { val := uint(*request.PayerAccountId) - var wallet models.Wallet - if err := database.C.Where("account_id = ?", val).First(&wallet).Error; err != nil { - return nil, status.Errorf(codes.NotFound, "payer wallet not found") + if err := database.C.Where("account_id = ?", val).First(&payerWallet).Error; err != nil { + return nil, status.Errorf(codes.NotFound, "payer wallet not found: %v", err) } - transaction.Payer = &wallet - transaction.PayerID = &wallet.ID } if request.PayeeAccountId != nil { val := uint(*request.PayeeAccountId) - var wallet models.Wallet - if err := database.C.Where("account_id = ?", val).First(&wallet).Error; err != nil { - return nil, status.Errorf(codes.NotFound, "payee wallet not found") + if err := database.C.Where("account_id = ?", val).First(&payeeWallet).Error; err != nil { + return nil, status.Errorf(codes.NotFound, "payee wallet not found: %v", err) } - transaction.Payee = &wallet - transaction.PayeeID = &wallet.ID } - if err := database.C.Create(&transaction).Error; err != nil { + transaction, err := services.MakeTransaction(request.GetAmount(), request.GetRemark(), payerWallet, payeeWallet) + if err != nil { return nil, status.Errorf(codes.Internal, err.Error()) } diff --git a/pkg/internal/services/payment.go b/pkg/internal/services/payment.go new file mode 100644 index 0000000..697cda0 --- /dev/null +++ b/pkg/internal/services/payment.go @@ -0,0 +1,50 @@ +package services + +import ( + "fmt" + + "git.solsynth.dev/hypernet/wallet/pkg/internal/database" + "git.solsynth.dev/hypernet/wallet/pkg/internal/models" + "github.com/shopspring/decimal" +) + +func MakeTransaction(amount float64, remark string, payer, payee *models.Wallet) (models.Transaction, error) { + transaction := models.Transaction{ + Amount: decimal.NewFromFloat(amount), + Remark: remark, + } + if payer != nil { + transaction.PayerID = &payer.ID + } + if payee != nil { + transaction.PayeeID = &payee.ID + } + + tx := database.C.Begin() + + if err := tx.Create(&transaction).Error; err != nil { + tx.Rollback() + return transaction, err + } + + if payer != nil { + payer.Balance = payer.Balance.Sub(transaction.Amount) + if err := tx.Model(payer). + Updates(&models.Wallet{Balance: payer.Balance}).Error; err != nil { + tx.Rollback() + return transaction, fmt.Errorf("failed to update payer wallet balance: %w", err) + } + } + if payee != nil { + payee.Balance = payee.Balance.Add(transaction.Amount) + if err := tx.Model(payee). + Updates(&models.Wallet{Balance: payee.Balance}).Error; err != nil { + tx.Rollback() + return transaction, fmt.Errorf("failed to update payee wallet balance: %w", err) + } + } + + tx.Commit() + + return transaction, nil +}