diff --git a/common/common.go b/common/common.go index cfbe5ab..6ca5102 100644 --- a/common/common.go +++ b/common/common.go @@ -293,3 +293,36 @@ func (f *ArrayFlag) Set(value string) error { *f = append(*f, value) return nil } + +// FlagsParse parse args to map +func FlagsParse(args []string, noValArg Set[string], schema map[string]string) map[string]string { + keyPos := 0 // position arg + keyGen := func() string { + keyPos++ + return fmt.Sprintf("-pos%d", keyPos) + } + resultMap := make(map[string]string) + var key string + for _, arg := range args { + switch { + case len(arg) > 2 && arg[:2] == "--": + key = arg[2:] + resultMap[key] = "" + case len(arg) > 1 && arg[0] == '-': + d, ok := schema[arg[1:]] + if ok && len(d) > 0 { + key = d + } else { + key = arg[1:] + } + resultMap[key] = "" + case len(arg) > 0 && arg[0] != '-': + if noValArg.Has(key) || key == "" { + key = keyGen() + } + resultMap[key] = arg + key = "" + } + } + return resultMap +} diff --git a/common/common_test.go b/common/common_test.go new file mode 100644 index 0000000..b985ee5 --- /dev/null +++ b/common/common_test.go @@ -0,0 +1,31 @@ +package common + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestFlagsParse(t *testing.T) { + args := []string{ + "-a", + "pos-arg1", + "-b", + "vb", + "pos-arg2", + } + noValArg := MakeSet[string](2) + noValArg.Insert("append") + schema := map[string]string{ + "a": "append", + "b": "block", + } + actual := FlagsParse(args, noValArg, schema) + + expected := map[string]string{ + "-pos1": "pos-arg1", + "-pos2": "pos-arg2", + "append": "", + "block": "vb", + } + assert.Equal(t, expected, actual) +} diff --git a/internal/specialcmd/specialcmd.go b/internal/specialcmd/specialcmd.go index 6c7e8c7..7959a0e 100644 --- a/internal/specialcmd/specialcmd.go +++ b/internal/specialcmd/specialcmd.go @@ -46,7 +46,7 @@ type cellStatus struct { func Parse(msg kernel.Message, goExec *goexec.State, execute bool, codeLines []string, usedLines Set[int]) (err error) { status := &cellStatus{} for lineNum := 0; lineNum < len(codeLines); lineNum++ { - if _, found := usedLines[lineNum]; found { + if usedLines.Has(lineNum) { continue } line := codeLines[lineNum] @@ -65,9 +65,19 @@ func Parse(msg kernel.Message, goExec *goexec.State, execute bool, codeLines []s if execute { switch cmdType { case '%': - err = execInternal(msg, goExec, cmdStr, status) - if err != nil { - return + parts := splitCmd(cmdStr) + // optimize... + if len(parts) > 0 && parts[0] == "writefile" { + cmdBody := parseCmdBody(codeLines, lineNum, usedLines) + err = execWriteFile(msg, goExec, parts[1:], cmdBody) + if err != nil { + return + } + } else { + err = execInternal(msg, goExec, cmdStr, status) + if err != nil { + return + } } case '!': err = execShell(msg, goExec, cmdStr, status) @@ -95,7 +105,7 @@ func Parse(msg kernel.Message, goExec *goexec.State, execute bool, codeLines []s func joinLine(lines []string, fromLine int, usedLines Set[int]) (cmdStr string) { for ; fromLine < len(lines); fromLine++ { cmdStr += lines[fromLine] - usedLines[fromLine] = struct{}{} + usedLines.Insert(fromLine) if cmdStr[len(cmdStr)-1] != '\\' { return } @@ -104,6 +114,23 @@ func joinLine(lines []string, fromLine int, usedLines Set[int]) (cmdStr string) return } +// parseCmdBody starts from fromLine and joins consecutive lines until the line start with magic symbol( % ! ) +// +// It returns the joined lines with the '\n', and appends the used lines (including fromLine) to usedLines. +func parseCmdBody(lines []string, fromLine int, usedLines Set[int]) (cmdBody string) { + usedLines.Insert(fromLine) + fromLine++ + for ; fromLine < len(lines); fromLine++ { + if len(lines[fromLine]) > 0 && (lines[fromLine][0] == '%' || lines[fromLine][0] == '!') { + return + } + cmdBody += lines[fromLine] + cmdBody += "\n" + usedLines.Insert(fromLine) + } + return +} + // execInternal executes internal configuration commands, see HelpMessage for details. // // It only returns errors for system errors that will lead to the kernel restart. Syntax errors @@ -275,6 +302,39 @@ func execInternal(msg kernel.Message, goExec *goexec.State, cmdStr string, statu return nil } +// execWriteFile write cell body to file +func execWriteFile(msg kernel.Message, goExec *goexec.State, args []string, cmdBody string) error { + // parse arg + noValArg := MakeSet[string](2) + noValArg.Insert("append") + schema := map[string]string{"a": "append"} + parse := FlagsParse(args, noValArg, schema) + _, appendMode := parse["append"] + filename, hasFileName := parse["-pos1"] + if !hasFileName { + filename = goExec.UniqueID + ".out" + } + + // do write + fileFlag := os.O_RDWR | os.O_CREATE + if appendMode { + fileFlag |= os.O_APPEND + } else { + fileFlag |= os.O_TRUNC + } + file, err := os.OpenFile(filename, fileFlag, 0666) + if err != nil { + return err + } + defer file.Close() + + _, err = file.WriteString(cmdBody) + if err != nil { + return err + } + return kernel.PublishWriteStream(msg, kernel.StreamStdout, "write to "+filename+" success\n") +} + // execInternal executes internal configuration commands, see HelpMessage for details. // // It only returns errors for system errors that will lead to the kernel restart. Syntax errors diff --git a/internal/specialcmd/specialcmd_test.go b/internal/specialcmd/specialcmd_test.go index 59261d2..854557f 100644 --- a/internal/specialcmd/specialcmd_test.go +++ b/internal/specialcmd/specialcmd_test.go @@ -61,3 +61,61 @@ func TestDirEnv(t *testing.T) { assert.Equal(t, "/tmp", os.Getenv(protocol.GONB_DIR_ENV)) require.NoError(t, s.Stop()) } + +func TestMagicWrite(t *testing.T) { + s := newEmptyState(t) + + expected := `fmt.Println("1") +fmt.Println("2") +// !*cat main.go +` + + type TestCase struct { + appendMode bool + filename, src, fileContent string + } + srcGen := func(testCase *TestCase) { + var appendArg string + if testCase.appendMode { + appendArg = " -a " + } + testCase.src = `%writefile ` + appendArg + testCase.filename + "\n" + expected + "%%\nfmt.Println(1)" + } + + // build test cases + testCases := []*TestCase{ + {false, "", "", expected}, + {true, "", "", strings.Repeat(expected, 2)}, + {false, "/tmp/TestMagicWrite.log", "", expected}, + } + for _, testCase := range testCases { + srcGen(testCase) + } + + // run test cases + fileClean := MakeSet[string](4) + defer func() { + for filename := range fileClean { + defer os.Remove(filename) + } + }() + for idx, testCase := range testCases { + t.Run(fmt.Sprintf("test-case-%d", idx), func(t *testing.T) { + filename := testCase.filename + if filename == "" { + filename = s.UniqueID + ".out" + } + fileClean.Insert(filename) + + var msg kernel.Message + usedLines := MakeSet[int]() + lines := strings.Split(testCase.src, "\n") + err := Parse(msg, s, true, lines, usedLines) + require.NoError(t, err) + + fileBytes, err := os.ReadFile(filename) + require.NoError(t, err) + assert.Equal(t, testCase.fileContent, string(fileBytes)) + }) + } +}