From b2d7036b962bc81ed7ba12d48d719e386ee57e8f Mon Sep 17 00:00:00 2001 From: Muninn Date: Sun, 17 Nov 2024 03:41:58 +0700 Subject: [PATCH] feat: backend, get wallet --- backend/api.go | 1 + backend/handler/vapi_function.go | 16 ++++++++++++++++ backend/service/service.go | 9 ++++++--- backend/service/vapi_function.go | 24 +++++++++++++++++++++++- 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/backend/api.go b/backend/api.go index 4b74871..2781d88 100644 --- a/backend/api.go +++ b/backend/api.go @@ -72,6 +72,7 @@ func Start(ctx context.Context) error { app.Post("/function/trade", h.VAPIFunctionTrade) app.Post("/function/confirm", h.VAPIFunctionConfirm) app.Post("/function/sign", h.VAPIFunctionSign) + app.Post("/function/get-wallet", h.VAPIFunctionGetWallet) app.Get("/chats/:id/state", h.GetChatState) diff --git a/backend/handler/vapi_function.go b/backend/handler/vapi_function.go index f33da9a..9285a65 100644 --- a/backend/handler/vapi_function.go +++ b/backend/handler/vapi_function.go @@ -65,3 +65,19 @@ func (h *Handler) VAPIFunctionSign(c *fiber.Ctx) error { return c.Status(fiber.StatusOK).JSON(resp) } + +func (h *Handler) VAPIFunctionGetWallet(c *fiber.Ctx) error { + var msg = new(types.VapiServerMessageToolCall) + err := c.BodyParser(msg) + if err != nil { + h.log.Error("VAPIFunctionGetWallet", "error", err) + h.log.Error("VAPIFunctionGetWallet body", "body", string(c.Body())) + return types.NewError(fiber.StatusBadRequest, "BadRequest", err.Error()) + } + resp, err := h.s.VAPIFunctionGetWallet(c.Context(), msg) + if err != nil { + return err + } + + return c.Status(fiber.StatusOK).JSON(resp) +} diff --git a/backend/service/service.go b/backend/service/service.go index f6fc0a7..c040e97 100644 --- a/backend/service/service.go +++ b/backend/service/service.go @@ -15,9 +15,11 @@ type Options struct { } type Service struct { - config *types.Config - log *slog.Logger - state sync.Map // call id -> state + config *types.Config + log *slog.Logger + state sync.Map // call id -> state + wallet string + balance sync.Map // symbol -> balance } func New(options Options) (*Service, error) { @@ -33,5 +35,6 @@ func New(options Options) (*Service, error) { return &Service{ config: options.Config, log: log, + wallet: "0x1234567890", }, nil } diff --git a/backend/service/vapi_function.go b/backend/service/vapi_function.go index 6f57519..b53acbf 100644 --- a/backend/service/vapi_function.go +++ b/backend/service/vapi_function.go @@ -28,7 +28,11 @@ func (s *Service) VAPIFunctionTrade(ctx context.Context, msg *types.VapiServerMe if !ok { return nil, types.NewError(fiber.StatusBadRequest, "BadRequest", "origin_token_amount is required") } - trade.OriginTokenAmount = amount.(float64) + trade.OriginTokenAmount, ok = amount.(float64) + if !ok { + s.log.Error("VAPIFunctionTrade", "error", "origin_token_amount is not float64") + return vapiToolResponse(id, "error, origin_token_amount is not float64"), nil + } action, ok := tool.Function.Arguments["destination_token_symbol"] if !ok { return nil, types.NewError(fiber.StatusBadRequest, "BadRequest", "destination_token_symbol is required") @@ -157,6 +161,24 @@ func (s *Service) VAPIFunctionSign(ctx context.Context, msg *types.VapiServerMes return resp, nil } +func (s *Service) VAPIFunctionGetWallet(ctx context.Context, msg *types.VapiServerMessageToolCall) (*types.ToolResults, error) { + var id string + var resp = new(types.ToolResults) + resp.Results = make([]types.ToolResult, 0) + for _, tool := range msg.Message.ToolCallList { + if tool.Function != nil { + if tool.Function.Name == "get_wallet_address" { + id = tool.Id + break + } + } + } + + s.log.Warn("VAPIFunction get wallet called") + vapiToolResponse(id, fmt.Sprintf("Your wallet is %s", s.wallet)) + return resp, nil +} + func (s *Service) VAPIFunction(ctx context.Context, genericMessage map[string]interface{}) error { // req, _ := json.Marshal(genericMessage) // s.log.Info("VAPIFunctionTrade called", "message", req)