diff --git a/chat/anthropic.go b/chat/anthropic.go index 8d79fdd..8ed89e6 100644 --- a/chat/anthropic.go +++ b/chat/anthropic.go @@ -2,6 +2,7 @@ package chat import ( "context" + "errors" "fmt" "github.com/kznrluk/aski/conv" "github.com/kznrluk/go-anthropic" @@ -47,6 +48,9 @@ func (a ap) rest(ctx context.Context, conv conv.Conversation) (string, error) { ) if err != nil { + if errors.Is(err, context.Canceled) { + return "", ErrCancelled + } return "", err } if len(rest.Content) == 0 { @@ -70,6 +74,9 @@ func (a ap) stream(ctx context.Context, conv conv.Conversation) (string, error) ) if err != nil { + if errors.Is(err, context.Canceled) { + return "", ErrCancelled + } return "", err } @@ -79,6 +86,8 @@ func (a ap) stream(ctx context.Context, conv conv.Conversation) (string, error) if err != nil { if err == io.EOF { break + } else if errors.Is(err, context.Canceled) { + return "", ErrCancelled } else { fmt.Printf("%s", err.Error()) return "", err diff --git a/chat/chat.go b/chat/chat.go index 12a282f..81fe132 100644 --- a/chat/chat.go +++ b/chat/chat.go @@ -2,6 +2,7 @@ package chat import ( "context" + "errors" "github.com/kznrluk/aski/config" "github.com/kznrluk/aski/conv" "os" @@ -18,6 +19,10 @@ type ( } ) +var ( + ErrCancelled = errors.New("cancelled") +) + func ProvideChat(model string, cfg config.Config) Chat { if strings.HasPrefix(model, "claude") { return NewAnthropic(cfg.AnthropicAPIKey) diff --git a/chat/openai.go b/chat/openai.go index 4e436e0..0f4e551 100644 --- a/chat/openai.go +++ b/chat/openai.go @@ -2,6 +2,7 @@ package chat import ( "context" + "errors" "fmt" "github.com/kznrluk/aski/conv" "github.com/sashabaranov/go-openai" @@ -61,6 +62,9 @@ func (o oai) rest(ctx context.Context, conv conv.Conversation) (string, error) { ) if err != nil { + if errors.Is(err, context.Canceled) { + return "", ErrCancelled + } return "", err } fmt.Printf("%s", resp.Choices[0].Message.Content) @@ -95,6 +99,9 @@ func (o oai) stream(ctx context.Context, conv conv.Conversation) (string, error) ) if err != nil { + if errors.Is(err, context.Canceled) { + return "", ErrCancelled + } return "", err } @@ -104,6 +111,8 @@ func (o oai) stream(ctx context.Context, conv conv.Conversation) (string, error) if err != nil { if err == io.EOF { break + } else if errors.Is(err, context.Canceled) { + return "", ErrCancelled } else { return "", err } diff --git a/conv/conversation.go b/conv/conversation.go index 43d25e9..93c70b5 100644 --- a/conv/conversation.go +++ b/conv/conversation.go @@ -131,6 +131,13 @@ func (c *conv) ChangeHead(sha1Partial string) (Message, error) { foundSha := false foundMessageIndex := -1 + if sha1Partial == "ROOT" { + for i := range c.Messages { + c.Messages[i].Head = false + } + return c.convertSystemToMessage(), nil + } + for i, message := range c.Messages { if strings.HasPrefix(message.Sha1, sha1Partial) { foundSha = true diff --git a/lib/dialog.go b/lib/dialog.go index a3c73e4..afab202 100644 --- a/lib/dialog.go +++ b/lib/dialog.go @@ -2,6 +2,7 @@ package lib import ( "context" + "errors" "fmt" "github.com/fatih/color" "github.com/kznrluk/aski/chat" @@ -100,6 +101,10 @@ func StartDialog(cfg config.Config, cv conv.Conversation, isRestMode bool, resto fmt.Printf("\n") data, err := cli.Retrieve(cv, isRestMode) if err != nil { + if errors.Is(err, chat.ErrCancelled) { + _, _ = cv.ChangeHead(last.ParentSha1) + continue + } fmt.Printf("\n%s", err.Error()) continue }