Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug: String() is unsafe, so rename to Print() and make a safer version #954

Merged
merged 2 commits into from
Mar 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 0 additions & 17 deletions pkg/assemble/assemble.go

This file was deleted.

2 changes: 1 addition & 1 deletion pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func Authorize(ctx engine.Context, input string) (runner.AuthorizerResponse, err

var result bool
err := survey.AskOne(&survey.Confirm{
Help: fmt.Sprintf("The full source of the tools is as follows:\n\n%s", ctx.Tool.String()),
Help: fmt.Sprintf("The full source of the tools is as follows:\n\n%s", ctx.Tool.Print()),
Default: true,
Message: ConfirmMessage(ctx, input),
}, &result)
Expand Down
4 changes: 2 additions & 2 deletions pkg/cli/fmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ func (e *Fmt) Run(_ *cobra.Command, args []string) error {
}

if e.Write && loc != "" {
return os.WriteFile(loc, []byte(doc.String()), 0644)
return os.WriteFile(loc, []byte(doc.Print()), 0644)
}

fmt.Print(doc.String())
fmt.Print(doc.Print())
return nil
}
16 changes: 0 additions & 16 deletions pkg/cli/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/fatih/color"
"github.com/gptscript-ai/cmd"
gptscript2 "github.com/gptscript-ai/go-gptscript"
"github.com/gptscript-ai/gptscript/pkg/assemble"
"github.com/gptscript-ai/gptscript/pkg/auth"
"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/cache"
Expand Down Expand Up @@ -58,7 +57,6 @@ type GPTScript struct {
// Input should not be using GPTSCRIPT_INPUT env var because that is the same value that is set in tool executions
Input string `usage:"Read input from a file (\"-\" for stdin)" short:"f" env:"GPTSCRIPT_INPUT_FILE"`
SubTool string `usage:"Use tool of this name, not the first tool in file" local:"true"`
Assemble bool `usage:"Assemble tool to a single artifact, saved to --output" hidden:"true" local:"true"`
ListModels bool `usage:"List the models available and exit" local:"true"`
ListTools bool `usage:"List built-in tools and exit" local:"true"`
ListenAddress string `usage:"Server listen address" default:"127.0.0.1:0" hidden:"true"`
Expand Down Expand Up @@ -439,20 +437,6 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
return cmd.Help()
}

if r.Assemble {
var out io.Writer = os.Stdout
if r.Output != "" && r.Output != "-" {
f, err := os.Create(r.Output)
if err != nil {
return fmt.Errorf("opening %s: %w", r.Output, err)
}
defer f.Close()
out = f
}

return assemble.Assemble(prg, out)
}

toolInput, err := input.FromCLI(r.Input, args)
if err != nil {
return err
Expand Down
57 changes: 13 additions & 44 deletions pkg/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (

"github.com/getkin/kin-openapi/openapi3"
"github.com/gptscript-ai/gptscript/internal"
"github.com/gptscript-ai/gptscript/pkg/assemble"
"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/hash"
Expand Down Expand Up @@ -132,36 +131,6 @@ func loadLocal(base *source, name string) (*source, bool, error) {
}, true, nil
}

func loadProgram(data []byte, into *types.Program, targetToolName, defaultModel string) (types.Tool, error) {
var ext types.Program

if err := json.Unmarshal(data[len(assemble.Header):], &ext); err != nil {
return types.Tool{}, err
}

into.ToolSet = make(map[string]types.Tool, len(ext.ToolSet))
for k, v := range ext.ToolSet {
if builtinTool, ok := builtin.DefaultModel(k, defaultModel); ok {
v = builtinTool
}
into.ToolSet[k] = v
}

tool := into.ToolSet[ext.EntryToolID]
if targetToolName == "" {
return tool, nil
}

tool, ok := into.ToolSet[tool.LocalTools[strings.ToLower(targetToolName)]]
if !ok {
return tool, &types.ErrToolNotFound{
ToolName: targetToolName,
}
}

return tool, nil
}

func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T {
var (
openAPICacheKey = hash.Digest(data)
Expand Down Expand Up @@ -189,14 +158,6 @@ func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T {
func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName, defaultModel string) ([]types.Tool, error) {
data := base.Content

if bytes.HasPrefix(data, assemble.Header) {
tool, err := loadProgram(data, prg, targetToolName, defaultModel)
if err != nil {
return nil, err
}
return []types.Tool{tool}, nil
}

var (
tools []types.Tool
isOpenAPI bool
Expand Down Expand Up @@ -231,11 +192,19 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base
// If we didn't get any tools from trying to parse it as OpenAPI, try to parse it as a GPTScript
if len(tools) == 0 {
var err error
tools, err = parser.ParseTools(bytes.NewReader(data), parser.Options{
AssignGlobals: true,
})
if err != nil {
return nil, err
_, marshaled, ok := strings.Cut(string(data), "#!GPTSCRIPT")
if ok {
err = json.Unmarshal([]byte(marshaled), &tools)
if err != nil {
return nil, fmt.Errorf("error parsing marshalled script: %w", err)
}
} else {
tools, err = parser.ParseTools(bytes.NewReader(data), parser.Options{
AssignGlobals: true,
})
if err != nil {
return nil, err
}
}
}

Expand Down
34 changes: 30 additions & 4 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,7 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
chatMessage.ToolCalls = append(chatMessage.ToolCalls, toToolCall(*content.ToolCall))
}
if content.Text != "" {
chatMessage.MultiContent = append(chatMessage.MultiContent, openai.ChatMessagePart{
Type: openai.ChatMessagePartTypeText,
Text: content.Text,
})
chatMessage.MultiContent = append(chatMessage.MultiContent, textToMultiContent(content.Text)...)
}
}

Expand All @@ -307,6 +304,35 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
return
}

const imagePrefix = "data:image/png;base64,"

func textToMultiContent(text string) []openai.ChatMessagePart {
var chatParts []openai.ChatMessagePart
parts := strings.Split(text, "\n")
for i := len(parts) - 1; i >= 0; i-- {
if strings.HasPrefix(parts[i], imagePrefix) {
chatParts = append(chatParts, openai.ChatMessagePart{
Type: openai.ChatMessagePartTypeImageURL,
ImageURL: &openai.ChatMessageImageURL{
URL: parts[i],
},
})
parts = parts[:i]
} else {
break
}
}
if len(parts) > 0 {
chatParts = append(chatParts, openai.ChatMessagePart{
Type: openai.ChatMessagePartTypeText,
Text: strings.Join(parts, "\n"),
})
}

slices.Reverse(chatParts)
return chatParts
}

func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
if err := c.ValidAuth(); err != nil {
if err := c.RetrieveAPIKey(ctx, env); err != nil {
Expand Down
38 changes: 38 additions & 0 deletions pkg/openai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,44 @@ import (
"github.com/hexops/valast"
)

func TestTextToMultiContent(t *testing.T) {
autogold.Expect([]openai.ChatMessagePart{{
Type: "text",
Text: "hi\n\n",
}}).Equal(t, textToMultiContent("hi\n\n"))

autogold.Expect([]openai.ChatMessagePart{
{
Type: "text",
Text: "hi",
},
{
Type: "image_url",
ImageURL: &openai.ChatMessageImageURL{URL: ""},
},
}).Equal(t, textToMultiContent("hi\n"))

autogold.Expect([]openai.ChatMessagePart{{
Type: "image_url",
ImageURL: &openai.ChatMessageImageURL{URL: ""},
}}).Equal(t, textToMultiContent(""))

autogold.Expect([]openai.ChatMessagePart{
{
Type: "text",
Text: "\none\ntwo",
},
{
Type: "image_url",
ImageURL: &openai.ChatMessageImageURL{URL: ""},
},
{
Type: "image_url",
ImageURL: &openai.ChatMessageImageURL{URL: ""},
},
}).Equal(t, textToMultiContent("\none\ntwo\n\n"))
}

func Test_appendMessage(t *testing.T) {
autogold.Expect(types.CompletionMessage{Content: []types.ContentPart{
{ToolCall: &types.CompletionToolCall{
Expand Down
4 changes: 2 additions & 2 deletions pkg/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func writeSep(buf *strings.Builder, lastText bool) {
}
}

func (d Document) String() string {
func (d Document) Print() string {
buf := strings.Builder{}
lastText := false
for _, node := range d.Nodes {
Expand All @@ -274,7 +274,7 @@ func (d Document) String() string {
}
if node.ToolNode != nil {
writeSep(&buf, lastText)
buf.WriteString(node.ToolNode.Tool.String())
buf.WriteString(node.ToolNode.Tool.Print())
lastText = false
}
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ body
!metadata:first:package.json
foo=base
f
`).Equal(t, tools[0].String())
`).Equal(t, tools[0].Print())
}

func TestFormatWithBadInstruction(t *testing.T) {
Expand All @@ -316,9 +316,9 @@ func TestFormatWithBadInstruction(t *testing.T) {
Instructions: "foo: bar",
},
}
autogold.Expect("Name: foo\n===\nfoo: bar\n").Equal(t, input.String())
autogold.Expect("Name: foo\n===\nfoo: bar\n").Equal(t, input.Print())

tools, err := ParseTools(strings.NewReader(input.String()))
tools, err := ParseTools(strings.NewReader(input.Print()))
require.NoError(t, err)
if reflect.DeepEqual(input, tools[0]) {
t.Errorf("expected %v, got %v", input, tools[0])
Expand Down
15 changes: 14 additions & 1 deletion pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,17 @@ func (r *Runner) newDispatcher(ctx context.Context) dispatcher {
return newParallelDispatcher(ctx)
}

func idForToolCall(id string, state *engine.Return) string {
if state == nil || state.State == nil {
return id
}
tc, ok := state.State.Pending[id]
if !ok || tc.Index == nil {
return id
}
return fmt.Sprintf("%03d", *tc.Index)
}

func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, state *State, toolCategory engine.ToolCategory) (*State, []SubCallResult, error) {
var (
resultLock sync.Mutex
Expand Down Expand Up @@ -698,7 +709,9 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,

// Sort the id so if sequential the results are predictable
ids := maps.Keys(state.Continuation.Calls)
sort.Strings(ids)
sort.Slice(ids, func(i, j int) bool {
return idForToolCall(ids[i], state.Continuation) < idForToolCall(ids[j], state.Continuation)
})

for _, id := range ids {
call := state.Continuation.Calls[id]
Expand Down
4 changes: 2 additions & 2 deletions pkg/sdkserver/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (s *server) listTools(w http.ResponseWriter, r *http.Request) {
// Don't print instructions
tool.Instructions = ""

lines = append(lines, tool.String())
lines = append(lines, tool.Print())
}

writeResponse(logger, w, map[string]any{"stdout": strings.Join(lines, "\n---\n")})
Expand Down Expand Up @@ -339,5 +339,5 @@ func (s *server) fmtDocument(w http.ResponseWriter, r *http.Request) {
return
}

writeResponse(logger, w, map[string]string{"stdout": doc.String()})
writeResponse(logger, w, map[string]string{"stdout": doc.Print()})
}
13 changes: 5 additions & 8 deletions pkg/sdkserver/types.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package sdkserver

import (
"encoding/json"
"maps"
"strings"
"time"

"github.com/gptscript-ai/gptscript/pkg/cache"
Expand Down Expand Up @@ -30,15 +30,12 @@ const (
type toolDefs []types.ToolDef

func (t toolDefs) String() string {
s := new(strings.Builder)
for i, tool := range t {
s.WriteString(tool.String())
if i != len(t)-1 {
s.WriteString("\n\n---\n\n")
}
data, err := json.Marshal(t)
if err != nil {
panic(err)
}

return s.String()
return "#!GPTSCRIPT" + string(data)
}

type (
Expand Down
8 changes: 8 additions & 0 deletions pkg/types/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,14 @@ func (t Tool) GetToolRefsFromNames(names []string) (result []ToolReference, _ er
}

func (t ToolDef) String() string {
data, err := json.Marshal([]any{t})
if err != nil {
panic(err)
}
return "#!GPTSCRIPT" + string(data)
}

func (t ToolDef) Print() string {
buf := &strings.Builder{}
if t.Parameters.GlobalModelName != "" {
_, _ = fmt.Fprintf(buf, "Global Model Name: %s\n", t.Parameters.GlobalModelName)
Expand Down
Loading