Skip to content

Commit

Permalink
Generate recursive input data struct
Browse files Browse the repository at this point in the history
  • Loading branch information
Peyton-Spencer committed Jan 26, 2025
1 parent cf489c4 commit 3e6ab84
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 7 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/omniaura/agentflow
go 1.23

require (
github.com/peyton-spencer/caseconv v0.1.1
github.com/peyton-spencer/caseconv v0.2.0
github.com/spf13/cobra v1.8.1
golang.org/x/sync v0.10.0
)
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/peyton-spencer/caseconv v0.1.1 h1:d8vGwInuHz+iSExyyV4UKCTY6+6nT8DybAg644ZiGsE=
github.com/peyton-spencer/caseconv v0.1.1/go.mod h1:ZnRiZGBCcE+32J4OTxPbFpEjzGNbywfEYHOIebQBA0A=
github.com/peyton-spencer/caseconv v0.2.0 h1:xVdG3AO6rEUQnSWz23Zv9mBfm3WTYSLzS1roMxCuFZo=
github.com/peyton-spencer/caseconv v0.2.0/go.mod h1:ZnRiZGBCcE+32J4OTxPbFpEjzGNbywfEYHOIebQBA0A=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
Expand Down
131 changes: 127 additions & 4 deletions pkg/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package ast
import (
"bytes"
"fmt"
"log/slog"
"slices"
"strings"

Expand Down Expand Up @@ -112,19 +113,141 @@ func (p Prompt) Vars(content []byte, c caseconv.Case) (vars [][]byte, length int
return
}

type InputStruct struct {
TopLevel []InputNode
}

func (i1 InputStruct) Equal(i2 InputStruct) bool {
if len(i1.TopLevel) != len(i2.TopLevel) {
return false
}
for i := range i1.TopLevel {
if !i1.TopLevel[i].Equal(i2.TopLevel[i]) {
return false
}
}
return true
}

func (i1 InputNode) Equal(i2 InputNode) bool {
if !bytes.Equal(i1.Name, i2.Name) {
return false
}
if len(i1.Subnodes) != len(i2.Subnodes) {
return false
}
for i := range i1.Subnodes {
if !i1.Subnodes[i].Equal(i2.Subnodes[i]) {
return false
}
}
return true
}

func (ii InputStruct) String() string {
var buf strings.Builder
buf.WriteString("type Input struct {\n")
for _, n := range ii.TopLevel {
buf.WriteString(" ")
buf.Write(n.Name)
buf.WriteString(" ")
buf.WriteString(n.String())
buf.WriteString("\n")
}
buf.WriteString("}")
return buf.String()
}

type InputNode struct {
Name []byte
Subnodes []InputNode
}

func (n InputNode) String() string {
var buf strings.Builder
if len(n.Subnodes) == 0 {
buf.WriteString("string")
return buf.String()
}

buf.WriteString("struct {\n")
for _, sub := range n.Subnodes {
buf.WriteString(" ")
buf.Write(sub.Name)
buf.WriteString(" ")
buf.WriteString(sub.String())
buf.WriteString("\n")
}
buf.WriteString(" }")
return buf.String()
}

func (ii *InputStruct) insertVar(node token.T, content []byte, c caseconv.Case) {
name := bytes.Split(node.Get(content), []byte{'.'})
if len(name) == 1 {
nn := c.BytCase(name[0])
ii.TopLevel = append(ii.TopLevel, InputNode{
Name: nn,
})
return
}
nn := c.BytCase(name[0])
idx := slices.IndexFunc(ii.TopLevel, func(n InputNode) bool {
return bytes.Equal(n.Name, nn)
})
if idx == -1 {
ii.TopLevel = append(ii.TopLevel, InputNode{
Name: nn,
})
idx = len(ii.TopLevel) - 1
}
ii.TopLevel[idx].insertMultiLevelVar(name[1:], c)
}

func (n *InputNode) insertMultiLevelVar(name [][]byte, c caseconv.Case) {
if len(name) == 0 {
slog.Debug("0 len name reached")
return
}
nn := c.BytCase(name[0])
idx := slices.IndexFunc(n.Subnodes, func(n InputNode) bool {
return bytes.Equal(n.Name, nn)
})
if idx == -1 {
n.Subnodes = append(n.Subnodes, InputNode{
Name: nn,
})
idx = len(n.Subnodes) - 1
}
if len(name) == 1 {
return
}
n.Subnodes[idx].insertMultiLevelVar(name[1:], c)
}

func (p Prompt) GetInputs(content []byte, c caseconv.Case) (ii InputStruct, err error) {
for _, node := range p.Nodes {
switch node.Kind {
case kind.Var, kind.OptionalBlock:
ii.insertVar(node, content, c)
}
}
return
}

func (p1 Prompt) Equal(p2 Prompt) bool {
return p1.Title == p2.Title && p1.Nodes.Equal(p2.Nodes)
}

func NewFile(name string, content []byte) (f File, err error) {
tokens, err := token.Tokenize(content)
if err != nil {
return
}
if !strings.HasSuffix(name, ".af") {
err = fmt.Errorf("file does not have .af extension: %s", name)
return
}
tokens, err := token.Tokenize(content)
if err != nil {
return
}
f.Name = strings.TrimSuffix(name, ".af")
f.Content = content
f.Prompts, err = newPrompts(tokens)
Expand Down
47 changes: 47 additions & 0 deletions pkg/ast/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/omniaura/agentflow/pkg/ast"
"github.com/omniaura/agentflow/pkg/token"
"github.com/omniaura/agentflow/pkg/token/kind"
"github.com/peyton-spencer/caseconv"
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -125,3 +126,49 @@ func TestNewFile(t *testing.T) {
func joinLines(in ...[]byte) []byte {
return bytes.Join(in, []byte{'\n'})
}

func TestGetInputs(t *testing.T) {
type InputTestCase struct {
name string
want ast.InputStruct
wantErr error
content []byte
c caseconv.Case
}

tcases := []InputTestCase{
{
name: "UserVariables.af",
want: ast.InputStruct{
TopLevel: []ast.InputNode{
{
Name: []byte("User"),
Subnodes: []ast.InputNode{
{Name: []byte("Name")},
{Name: []byte("Email")},
{Name: []byte("Age")},
{Name: []byte("Subscription"), Subnodes: []ast.InputNode{
{Name: []byte("Plan")},
{Name: []byte("Status")},
}},
},
},
{Name: []byte("Message")},
},
},
content: []byte("<!user.name> <!user.email> <!user.age> <!message> <!user.subscription.plan> <!user.subscription.status>"),
c: caseconv.CaseCamel,
},
}

for _, tc := range tcases {
t.Run(tc.name, func(t *testing.T) {
f, err := ast.NewFile(tc.name, tc.content)
require.NoError(t, err)
ii, err := f.Prompts[0].GetInputs(tc.content, tc.c)
require.NoError(t, err)
require.Equal(t, tc.want, ii)
t.Logf("got:\n%s", ii.String())
})
}
}

0 comments on commit 3e6ab84

Please sign in to comment.