From 26da4c4e41fcd8886e87c0c404d019ee9e3215cd Mon Sep 17 00:00:00 2001 From: MFarrasHakim Date: Thu, 26 Oct 2023 08:02:46 +0700 Subject: [PATCH] feat(OpenAI): Custom API version and functionCall option --- example/evaluator/main.go | 2 +- example/functioncall/main.go | 2 +- example/simpleChatBot/main.go | 2 +- example/vector_conversation/main.go | 2 +- example/websocketSimpleChat/wssimplechat.go | 2 +- go.mod | 3 +++ go.sum | 7 +++++++ model/function.go | 7 +++++-- model/model.go | 7 +++++++ model/openAI/openaiChat.go | 8 +++++++- 10 files changed, 34 insertions(+), 8 deletions(-) diff --git a/example/evaluator/main.go b/example/evaluator/main.go index 221ebd4..897b16a 100644 --- a/example/evaluator/main.go +++ b/example/evaluator/main.go @@ -48,7 +48,7 @@ func main() { } var authToken = os.Getenv("OPENAI_API_KEY") - chatModel := _openai.NewOpenAIChatModel(authToken, "", "", _openai.GPT3Dot5Turbo0301, callback.NewManager(), false) + chatModel := _openai.NewOpenAIChatModel(authToken, "", "", "", _openai.GPT3Dot5Turbo0301, callback.NewManager(), false) testRunner(tests, chatModel) } diff --git a/example/functioncall/main.go b/example/functioncall/main.go index 350ef38..e89176b 100644 --- a/example/functioncall/main.go +++ b/example/functioncall/main.go @@ -15,7 +15,7 @@ import ( func main() { var authToken = os.Getenv("OPENAI_API_KEY") var chatModel model.ChatModel - chatModel = _openai.NewOpenAIChatModel(authToken, "", "", _openai.GPT3Dot5Turbo0301, callback.NewManager(), false) + chatModel = _openai.NewOpenAIChatModel(authToken, "", "", "", _openai.GPT3Dot5Turbo0301, callback.NewManager(), false) memory := []model.ChatMessage{} greeter := greeting.NewGreetingTool() diff --git a/example/simpleChatBot/main.go b/example/simpleChatBot/main.go index ec89115..e787a22 100644 --- a/example/simpleChatBot/main.go +++ b/example/simpleChatBot/main.go @@ -17,7 +17,7 @@ func main() { fmt.Println("Type .quit to exit") var authToken = os.Getenv("OPENAI_API_KEY") - chatModel := _openai.NewOpenAIChatModel(authToken, "", "", _openai.GPT3Dot5Turbo0301, callback.NewManager(), false) + chatModel := _openai.NewOpenAIChatModel(authToken, "", "", "", _openai.GPT3Dot5Turbo0301, callback.NewManager(), false) memory := []model.ChatMessage{} streamingChannel := make(chan model.ChatMessage, 100) convoChain := conversation.NewConversationChain(chatModel, memory, callback.NewManager(), "You're helpful chatbot that answer human question very concisely", false) diff --git a/example/vector_conversation/main.go b/example/vector_conversation/main.go index be42c12..9a73ced 100644 --- a/example/vector_conversation/main.go +++ b/example/vector_conversation/main.go @@ -45,7 +45,7 @@ func Init() (err error) { return } embeddingModel = _openai.NewOpenAIEmbedModel(OAIauthToken, "", "", openai.AdaEmbeddingV2) - chatModel = _openai.NewOpenAIChatModel(OAIauthToken, "", "", _openai.GPT3Dot5Turbo0301, callback.NewManager(), false) + chatModel = _openai.NewOpenAIChatModel(OAIauthToken, "", "", "", _openai.GPT3Dot5Turbo0301, callback.NewManager(), false) wvClient, err = weaviateVS.NewWeaviateVectorStore(wvhost, wvscheme, wvApiKey, embeddingModel, nil) if err != nil { diff --git a/example/websocketSimpleChat/wssimplechat.go b/example/websocketSimpleChat/wssimplechat.go index e9ea3fd..5831055 100644 --- a/example/websocketSimpleChat/wssimplechat.go +++ b/example/websocketSimpleChat/wssimplechat.go @@ -25,7 +25,7 @@ var authToken = os.Getenv("OPENAI_API_KEY") var chatModel *_openai.OpenAIChatModel func main() { - chatModel = _openai.NewOpenAIChatModel(authToken, "", "", _openai.GPT3Dot5Turbo0301, callback.NewManager(), false) + chatModel = _openai.NewOpenAIChatModel(authToken, "", "", "", _openai.GPT3Dot5Turbo0301, callback.NewManager(), false) fs := http.FileServer(http.Dir(".")) http.Handle("/", fs) diff --git a/go.mod b/go.mod index 06666b4..b7af66a 100644 --- a/go.mod +++ b/go.mod @@ -13,10 +13,13 @@ require ( ) require ( + github.com/alecthomas/jsonschema v0.0.0-20220216202328-9eeeec9d044b // indirect github.com/elastic/elastic-transport-go/v8 v8.3.0 // indirect + github.com/iancoleman/orderedmap v0.0.0-20190318233801-ac98e3ecb4b0 // indirect github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.16 // indirect + github.com/tfkhsr/jsonschema v0.0.0-20180218143334-273afdd5a88c // indirect golang.org/x/sys v0.6.0 // indirect ) diff --git a/go.sum b/go.sum index 72254a9..fd81f0e 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/PuerkitoBio/purell v1.1.1 h1:WEQqlqaGbrPkxLJWfBwQmfEAE1Z7ONdDLqrN38tN github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV/sSk/8dngufqelfh6jnri85riMAaF/M= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= +github.com/alecthomas/jsonschema v0.0.0-20220216202328-9eeeec9d044b h1:doCpXjVwui6HUN+xgNsNS3SZ0/jUZ68Eb+mJRNOZfog= +github.com/alecthomas/jsonschema v0.0.0-20220216202328-9eeeec9d044b/go.mod h1:/n6+1/DWPltRLWL/VKyUxg6tzsl5kHUCcraimt4vr60= github.com/anaskhan96/soup v1.2.5 h1:V/FHiusdTrPrdF4iA1YkVxsOpdNcgvqT1hG+YtcZ5hM= github.com/anaskhan96/soup v1.2.5/go.mod h1:6YnEp9A2yywlYdM4EgDz9NEHclocMepEtku7wg6Cq3s= github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= @@ -83,6 +85,8 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/iancoleman/orderedmap v0.0.0-20190318233801-ac98e3ecb4b0 h1:i462o439ZjprVSFSZLZxcsoAe592sZB1rci2Z8j4wdk= +github.com/iancoleman/orderedmap v0.0.0-20190318233801-ac98e3ecb4b0/go.mod h1:N0Wam8K1arqPXNWjMo21EXnBPOPp36vB07FNRdD2geA= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= @@ -146,12 +150,15 @@ github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.3.1-0.20190311161405-34c6fa2dc709/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/tfkhsr/jsonschema v0.0.0-20180218143334-273afdd5a88c h1:FiJHojQ8AwCcltJnytC3Xkj37gW2WTzUzGl3AEYL+5U= +github.com/tfkhsr/jsonschema v0.0.0-20180218143334-273afdd5a88c/go.mod h1:zhGMpmE6P0Eml0MgFIc5TljSWlr/hbNSmig8KiVEodo= github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/trietmn/go-wiki v1.0.1 h1:OnKPSfE/XtWH9ybRxD7UcNv4bLzv8WcTWxwMcIDsFyg= diff --git a/model/function.go b/model/function.go index bd0dfb4..89f5d7d 100644 --- a/model/function.go +++ b/model/function.go @@ -5,8 +5,10 @@ import "strings" type DataType string const ( - FunctionDataTypeString DataType = "string" - FunctionDataTypeObject DataType = "object" + FunctionDataTypeBoolean DataType = "boolean" + FunctionDataTypeString DataType = "string" + FunctionDataTypeObject DataType = "object" + FunctionDataTypeArray DataType = "array" ) type FunctionJsonSchema struct { @@ -15,6 +17,7 @@ type FunctionJsonSchema struct { Required []string `json:"required,omitempty"` Description string `json:"description,omitempty"` Enum []string `json:"enum,omitempty"` + Items *FunctionJsonSchema `json:"items,omitempty"` } // FunctionDefinition is to describe function to model diff --git a/model/model.go b/model/model.go index 0f01808..99d6fd6 100644 --- a/model/model.go +++ b/model/model.go @@ -26,6 +26,7 @@ type Option struct { Temperature float32 StreamingChannel chan ChatMessage // non chat model can also use this Functions []FunctionDefinition + FunctionCall any AdditionalMetadataFields []string MaxToken int IsStreaming bool @@ -67,3 +68,9 @@ func WithFunctions(function []FunctionDefinition) func(*Option) { o.Functions = function } } + +func WithFunctionCall(functionCall any) func(*Option) { + return func(o *Option) { + o.FunctionCall = functionCall + } +} diff --git a/model/openAI/openaiChat.go b/model/openAI/openaiChat.go index 0da5ed8..b49f8d1 100644 --- a/model/openAI/openaiChat.go +++ b/model/openAI/openaiChat.go @@ -20,13 +20,16 @@ type OpenAIChatModel struct { } // NewOpenAIChatModel return new openAI Model instance -func NewOpenAIChatModel(authToken string, orgID string, baseURL string, modelName string, callbackManager *callback.Manager, verbose bool) (llm *OpenAIChatModel) { +func NewOpenAIChatModel(authToken string, orgID string, baseURL string, apiVersion string, modelName string, callbackManager *callback.Manager, verbose bool) (llm *OpenAIChatModel) { var client *goopenai.Client if baseURL == "" { client = goopenai.NewClient(authToken) } else { config := goopenai.DefaultAzureConfig(authToken, baseURL) config.OrgID = orgID + if apiVersion != "" { + config.APIVersion = apiVersion + } client = goopenai.NewClientWithConfig(config) } @@ -96,6 +99,9 @@ func (O *OpenAIChatModel) Chat(ctx context.Context, messages []model.ChatMessage Functions: RequestFunctions, Stream: false, } + if opts.FunctionCall != nil { + request.FunctionCall = opts.FunctionCall + } response, err := O.c.CreateChatCompletion(ctx, request) if err != nil {