From ece2f03b3c12106e668194b9cbb6ee4951a9f630 Mon Sep 17 00:00:00 2001 From: Rossiar Date: Mon, 12 Feb 2024 20:25:43 -0500 Subject: [PATCH] bytecode: Implement addition and subtraction These simple operations for integer mathematics will be expanded upon in the next change, which will add the multiplication, division and modulo operations. This change also standardises the compiler and vm test layouts to increase their readability. Co-authored-by: joshcarp Co-authored-by: pgmitche --- pkg/bytecode/code.go | 24 ++++- pkg/bytecode/compiler.go | 34 +++++- pkg/bytecode/compiler_test.go | 88 ---------------- pkg/bytecode/value.go | 25 ++--- pkg/bytecode/vm.go | 42 ++++++-- pkg/bytecode/vm_test.go | 188 ++++++++++++++++++++++++++++++---- 6 files changed, 262 insertions(+), 139 deletions(-) delete mode 100644 pkg/bytecode/compiler_test.go diff --git a/pkg/bytecode/code.go b/pkg/bytecode/code.go index be0206c9..2972d507 100644 --- a/pkg/bytecode/code.go +++ b/pkg/bytecode/code.go @@ -16,10 +16,23 @@ const ( // OpSetGlobal adds a symbol to the specified index in the symbol // table. OpSetGlobal + // OpAdd instructs the virtual machine to perform an addition. + OpAdd + // OpSubtract instructs the virtual machine to perform a subtraction. + OpSubtract ) -// ErrUnknownOp is returned when an unknown opcode is encountered. -var ErrUnknownOp = errors.New("unknown opcode") +var ( + // ErrInternal and errors wrapping ErrInternal report internal + // errors of the VM that should not occur during normal + // program execution. + ErrInternal = errors.New("internal error") + // ErrPanic and errors wrapping ErrPanic report runtime errors, such + // as an index out of bounds or a stack overflow. + ErrPanic = errors.New("user error") + // ErrUnknownOpcode is returned when an unknown opcode is encountered. + ErrUnknownOpcode = fmt.Errorf("%w: unknown opcode", ErrInternal) +) // definitions is a mapping of OpCode to OpDefinition. var definitions = map[Opcode]*OpDefinition{ @@ -29,6 +42,11 @@ var definitions = map[Opcode]*OpDefinition{ OpConstant: {"OpConstant", []int{2}}, OpGetGlobal: {"OpGetGlobal", []int{2}}, OpSetGlobal: {"OpSetGlobal", []int{2}}, + // Operations like OpAdd have no operand width because the virtual + // machine is expected to pop the values from the stack when reading + // this instruction. + OpAdd: {"OpAdd", nil}, + OpSubtract: {"OpSubtract", nil}, } // OpDefinition defines a name and expected operand width for each OpCode. @@ -92,7 +110,7 @@ func ReadUint16(ins Instructions) uint16 { func Lookup(op Opcode) (*OpDefinition, error) { def, ok := definitions[op] if !ok { - return nil, fmt.Errorf("%w: %d", ErrUnknownOp, op) + return nil, fmt.Errorf("%w: %d", ErrUnknownOpcode, op) } return def, nil } diff --git a/pkg/bytecode/compiler.go b/pkg/bytecode/compiler.go index 59850056..55af0051 100644 --- a/pkg/bytecode/compiler.go +++ b/pkg/bytecode/compiler.go @@ -1,15 +1,19 @@ package bytecode import ( - "errors" "fmt" "evylang.dev/evy/pkg/parser" ) -// ErrUndefinedVar is returned when a variable name cannot -// be resolved in the symbol table. -var ErrUndefinedVar = errors.New("undefined variable") +var ( + // ErrUndefinedVar is returned when a variable name cannot + // be resolved in the symbol table. + ErrUndefinedVar = fmt.Errorf("%w: undefined variable", ErrPanic) + // ErrUnknownOperator is returned when an operator cannot + // be resolved. + ErrUnknownOperator = fmt.Errorf("%w: unknown operator", ErrInternal) +) // Compiler is responsible for turning a parsed evy program into // bytecode. @@ -43,10 +47,12 @@ func (c *Compiler) Compile(node parser.Node) error { return c.compileDecl(node.Decl) case *parser.AssignmentStmt: return c.compileAssignment(node) + case *parser.BinaryExpression: + return c.compileBinaryExpression(node) case *parser.Var: return c.compileVar(node) case *parser.NumLiteral: - num := &numVal{V: node.Value} + num := numVal(node.Value) if err := c.emit(OpConstant, c.addConstant(num)); err != nil { return err } @@ -86,6 +92,24 @@ func (c *Compiler) emit(op Opcode, operands ...int) error { return nil } +func (c *Compiler) compileBinaryExpression(expr *parser.BinaryExpression) error { + if err := c.Compile(expr.Left); err != nil { + return err + } + if err := c.Compile(expr.Right); err != nil { + return err + } + switch expr.Op { + case parser.OP_PLUS: + return c.emit(OpAdd) + case parser.OP_MINUS: + return c.emit(OpSubtract) + // more operators to follow (*, /, %). + default: + return fmt.Errorf("%w %s", ErrUnknownOperator, expr.Op) + } +} + func (c *Compiler) compileProgram(prog *parser.Program) error { for _, s := range prog.Statements { if err := c.Compile(s); err != nil { diff --git a/pkg/bytecode/compiler_test.go b/pkg/bytecode/compiler_test.go deleted file mode 100644 index 7f3b0a99..00000000 --- a/pkg/bytecode/compiler_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package bytecode - -import ( - "testing" - - "evylang.dev/evy/pkg/assert" - "evylang.dev/evy/pkg/parser" -) - -type compilerTestCase struct { - input string - expectedConstants []interface{} - expectedInstructions []Instructions -} - -func TestGlobalVarStatements(t *testing.T) { - tests := []compilerTestCase{ - { - input: "x := 1\nx = x", - expectedConstants: []interface{}{1}, - expectedInstructions: []Instructions{ - mustMake(t, OpConstant, 0), - mustMake(t, OpSetGlobal, 0), - mustMake(t, OpGetGlobal, 0), - mustMake(t, OpSetGlobal, 0), - }, - }, - { - input: "x := 2\nx = x", - expectedConstants: []interface{}{2}, - expectedInstructions: []Instructions{ - mustMake(t, OpConstant, 0), - mustMake(t, OpSetGlobal, 0), - mustMake(t, OpGetGlobal, 0), - mustMake(t, OpSetGlobal, 0), - }, - }, - } - for _, tt := range tests { - program, err := parser.Parse(tt.input, parser.Builtins{}) - assert.NoError(t, err, "parser error") - compiler := NewCompiler() - err = compiler.Compile(program) - assert.NoError(t, err, "compiler error") - bytecode := compiler.Bytecode() - assertInstructions(t, tt.expectedInstructions, bytecode.Instructions) - assertConstants(t, tt.expectedConstants, bytecode.Constants) - } -} - -func assertInstructions(t *testing.T, expected []Instructions, actual Instructions) { - t.Helper() - concatted := concatInstructions(expected) - assert.Equal(t, len(concatted), len(actual), "wrong instructions length") - for i, ins := range concatted { - assert.Equal(t, ins, actual[i], "wrong instruction %d", i) - } -} - -func concatInstructions(s []Instructions) Instructions { - out := Instructions{} - for _, ins := range s { - out = append(out, ins...) - } - return out -} - -func assertConstants(t *testing.T, expected []interface{}, actual []value) { - t.Helper() - assert.Equal(t, len(expected), len(actual), "wrong number of constants") - for i, constant := range expected { - switch constant := constant.(type) { - case int: - assertNumValue(t, float64(constant), actual[i]) - case float64: - assertNumValue(t, constant, actual[i]) - default: - t.Errorf("unknown constant type %v", constant) - } - } -} - -func assertNumValue(t *testing.T, expected float64, actual value) { - t.Helper() - result, ok := actual.(*numVal) - assert.Equal(t, true, ok, "object is not a NumVal. got=%T (%+v)", actual, actual) - assert.Equal(t, expected, result.V, "object has wrong value") -} diff --git a/pkg/bytecode/value.go b/pkg/bytecode/value.go index 60d3578d..44268de3 100644 --- a/pkg/bytecode/value.go +++ b/pkg/bytecode/value.go @@ -10,29 +10,20 @@ type value interface { Type() *parser.Type Equals(value) bool String() string - Set(value) } -type numVal struct { - V float64 -} - -func (n *numVal) Type() *parser.Type { return parser.NUM_TYPE } +type numVal float64 -func (n *numVal) String() string { return strconv.FormatFloat(n.V, 'f', -1, 64) } +func (n numVal) Type() *parser.Type { return parser.NUM_TYPE } -func (n *numVal) Equals(v value) bool { - n2, ok := v.(*numVal) - if !ok { - panic("internal error: Num.Equals called with non-Num value") - } - return n.V == n2.V +func (n numVal) String() string { + return strconv.FormatFloat(float64(n), 'f', -1, 64) } -func (n *numVal) Set(v value) { - n2, ok := v.(*numVal) +func (n numVal) Equals(v value) bool { + n2, ok := v.(numVal) if !ok { - panic("internal error: Num.Set called with with non-Num value") + panic("internal error: Num.Equals called with non-Num value") } - *n = *n2 + return n == n2 } diff --git a/pkg/bytecode/vm.go b/pkg/bytecode/vm.go index bb83a4e1..d7c4ffea 100644 --- a/pkg/bytecode/vm.go +++ b/pkg/bytecode/vm.go @@ -1,7 +1,7 @@ package bytecode import ( - "errors" + "fmt" ) const ( @@ -13,7 +13,7 @@ const ( ) // ErrStackOverflow is returned when the stack exceeds its size limit. -var ErrStackOverflow = errors.New("stack overflow") +var ErrStackOverflow = fmt.Errorf("%w: stack overflow", ErrPanic) // VM is responsible for executing evy programs from bytecode. type VM struct { @@ -48,21 +48,29 @@ func (vm *VM) Run() error { case OpConstant: constIndex := ReadUint16(vm.instructions[ip+1:]) ip += 2 - err := vm.push(vm.constants[constIndex]) - if err != nil { + if err := vm.push(vm.constants[constIndex]); err != nil { return err } case OpGetGlobal: globalIndex := ReadUint16(vm.instructions[ip+1:]) ip += 2 - err := vm.push(vm.globals[globalIndex]) - if err != nil { + if err := vm.push(vm.globals[globalIndex]); err != nil { return err } case OpSetGlobal: globalIndex := ReadUint16(vm.instructions[ip+1:]) ip += 2 vm.globals[globalIndex] = vm.pop() + case OpAdd: + right, left := vm.popBinaryNums() + if err := vm.push(numVal(left + right)); err != nil { + return err + } + case OpSubtract: + right, left := vm.popBinaryNums() + if err := vm.push(numVal(left - right)); err != nil { + return err + } } } return nil @@ -93,3 +101,25 @@ func (vm *VM) pop() value { vm.sp-- return o } + +// popBinaryNums pops the top two elements of the stack (the left +// and right sides of the binary expressions) as nums and returns both. +func (vm *VM) popBinaryNums() (float64, float64) { + // the right was compiled last, so is higher on the stack + // than the left + right := vm.popNumVal() + left := vm.popNumVal() + return float64(right), float64(left) +} + +// popNumVal pops an element from the stack and casts it to a num +// before returning the value. If elem is not a num it will error. +func (vm *VM) popNumVal() numVal { + elem := vm.pop() + val, ok := elem.(numVal) + if !ok { + panic(fmt.Errorf("%w: expected to pop numVal but got %s", + ErrInternal, elem.Type())) + } + return val +} diff --git a/pkg/bytecode/vm_test.go b/pkg/bytecode/vm_test.go index 52e34571..b59b48a5 100644 --- a/pkg/bytecode/vm_test.go +++ b/pkg/bytecode/vm_test.go @@ -7,33 +7,181 @@ import ( "evylang.dev/evy/pkg/parser" ) -func TestIntegerArithmetic(t *testing.T) { - tests := []vmTestCase{ - {"x := 1\nx = x", 1}, - {"y := 2\ny = y", 2}, +func TestVMGlobals(t *testing.T) { + tests := []testCase{ + { + name: "global assignment", + input: ` + x := 1 + x = x + `, + expectedStackTop: 1, + expectedConstants: []any{1}, + expectedInstructions: []Instructions{ + mustMake(t, OpConstant, 0), + mustMake(t, OpSetGlobal, 0), + mustMake(t, OpGetGlobal, 0), + mustMake(t, OpSetGlobal, 0), + }, + }, } for _, tt := range tests { - program, err := parser.Parse(tt.input, parser.Builtins{}) - assert.NoError(t, err, "parser error") - comp := NewCompiler() - err = comp.Compile(program) - assert.NoError(t, err, "compiler error") - vm := NewVM(comp.Bytecode()) - err = vm.Run() - assert.NoError(t, err, "vm error") - stackElem := vm.lastPoppedStackElem() - switch expected := tt.expected.(type) { + t.Run(tt.name, func(t *testing.T) { + program, err := parser.Parse(tt.input, parser.Builtins{}) + assert.NoError(t, err, "parser error") + comp := NewCompiler() + err = comp.Compile(program) + assert.NoError(t, err, "compiler error") + vm := NewVM(comp.Bytecode()) + err = vm.Run() + assert.NoError(t, err, "runtime error") + stackElem := vm.lastPoppedStackElem() + switch expected := tt.expectedStackTop.(type) { + case int: + assertNumValue(t, float64(expected), stackElem) + case float64: + assertNumValue(t, expected, stackElem) + default: + t.Errorf("unexpected object type %v", expected) + } + }) + } +} + +func TestVMArithmetic(t *testing.T) { + tests := []testCase{ + { + name: "addition", + input: ` + x := 2 + 1 + x = x + `, + expectedStackTop: 3, + expectedConstants: []any{2, 1}, + expectedInstructions: []Instructions{ + mustMake(t, OpConstant, 0), + mustMake(t, OpConstant, 1), + mustMake(t, OpAdd), + mustMake(t, OpSetGlobal, 0), + mustMake(t, OpGetGlobal, 0), + mustMake(t, OpSetGlobal, 0), + }, + }, + { + name: "subtraction", + input: ` + x := 2 - 1 + x = x + `, + expectedStackTop: 1, + expectedConstants: []any{2, 1}, + expectedInstructions: []Instructions{ + mustMake(t, OpConstant, 0), + mustMake(t, OpConstant, 1), + mustMake(t, OpSubtract), + mustMake(t, OpSetGlobal, 0), + mustMake(t, OpGetGlobal, 0), + mustMake(t, OpSetGlobal, 0), + }, + }, + { + name: "addition and subtraction", + input: ` + x := 2 - 1 + 3 + x = x + `, + expectedStackTop: 4, + expectedConstants: []any{2, 1, 3}, + expectedInstructions: []Instructions{ + mustMake(t, OpConstant, 0), + mustMake(t, OpConstant, 1), + mustMake(t, OpSubtract), + mustMake(t, OpConstant, 2), + mustMake(t, OpAdd), + mustMake(t, OpSetGlobal, 0), + mustMake(t, OpGetGlobal, 0), + mustMake(t, OpSetGlobal, 0), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + program, err := parser.Parse(tt.input, parser.Builtins{}) + assert.NoError(t, err, "parser error") + comp := NewCompiler() + err = comp.Compile(program) + assert.NoError(t, err, "compiler error") + bytecode := comp.Bytecode() + assertInstructions(t, tt.expectedInstructions, bytecode.Instructions) + assertConstants(t, tt.expectedConstants, bytecode.Constants) + vm := NewVM(bytecode) + err = vm.Run() + assert.NoError(t, err, "runtime error") + stackElem := vm.lastPoppedStackElem() + switch expected := tt.expectedStackTop.(type) { + case int: + assertNumValue(t, float64(expected), stackElem) + case float64: + assertNumValue(t, expected, stackElem) + default: + t.Errorf("unexpected object type %v", expected) + } + }) + } +} + +// testCase covers both the compiler and the VM. +type testCase struct { + name string + // input is an evy program + input string + // expectedStackTop is the result of popping the last + // element from the stack in the vm. + expectedStackTop any + // expectedConstants are the expected constants passed in the + // bytecode after compilation. + expectedConstants []any + // expectedInstructions are the expected compiler instructions + // passed in the bytecode after compilation + expectedInstructions []Instructions +} + +func assertInstructions(t *testing.T, expected []Instructions, actual Instructions) { + t.Helper() + concatted := concatInstructions(expected) + assert.Equal(t, len(concatted), len(actual), "wrong instructions length") + for i, ins := range concatted { + assert.Equal(t, ins, actual[i], "wrong instruction at %04d\nwant=\n%s\ngot=\n%s", + i, concatted, actual) + } +} + +func concatInstructions(s []Instructions) Instructions { + out := Instructions{} + for _, ins := range s { + out = append(out, ins...) + } + return out +} + +func assertConstants(t *testing.T, expected []any, actual []value) { + t.Helper() + assert.Equal(t, len(expected), len(actual), "wrong number of constants") + for i, constant := range expected { + switch constant := constant.(type) { case int: - assertNumValue(t, float64(expected), stackElem) + assertNumValue(t, float64(constant), actual[i]) case float64: - assertNumValue(t, expected, stackElem) + assertNumValue(t, constant, actual[i]) default: - t.Errorf("unexpected object type %v", expected) + t.Errorf("unknown constant type %v", constant) } } } -type vmTestCase struct { - input string - expected interface{} +func assertNumValue(t *testing.T, expected float64, actual value) { + t.Helper() + result, ok := actual.(numVal) + assert.Equal(t, true, ok, "object is not a NumVal. got=%T (%+v)", actual, actual) + assert.Equal(t, expected, float64(result), "object has wrong value") }