diff --git a/pkg/bytecode/code.go b/pkg/bytecode/code.go index be0206c9..2972d507 100644 --- a/pkg/bytecode/code.go +++ b/pkg/bytecode/code.go @@ -16,10 +16,23 @@ const ( // OpSetGlobal adds a symbol to the specified index in the symbol // table. OpSetGlobal + // OpAdd instructs the virtual machine to perform an addition. + OpAdd + // OpSubtract instructs the virtual machine to perform a subtraction. + OpSubtract ) -// ErrUnknownOp is returned when an unknown opcode is encountered. -var ErrUnknownOp = errors.New("unknown opcode") +var ( + // ErrInternal and errors wrapping ErrInternal report internal + // errors of the VM that should not occur during normal + // program execution. + ErrInternal = errors.New("internal error") + // ErrPanic and errors wrapping ErrPanic report runtime errors, such + // as an index out of bounds or a stack overflow. + ErrPanic = errors.New("user error") + // ErrUnknownOpcode is returned when an unknown opcode is encountered. + ErrUnknownOpcode = fmt.Errorf("%w: unknown opcode", ErrInternal) +) // definitions is a mapping of OpCode to OpDefinition. var definitions = map[Opcode]*OpDefinition{ @@ -29,6 +42,11 @@ var definitions = map[Opcode]*OpDefinition{ OpConstant: {"OpConstant", []int{2}}, OpGetGlobal: {"OpGetGlobal", []int{2}}, OpSetGlobal: {"OpSetGlobal", []int{2}}, + // Operations like OpAdd have no operand width because the virtual + // machine is expected to pop the values from the stack when reading + // this instruction. + OpAdd: {"OpAdd", nil}, + OpSubtract: {"OpSubtract", nil}, } // OpDefinition defines a name and expected operand width for each OpCode. @@ -92,7 +110,7 @@ func ReadUint16(ins Instructions) uint16 { func Lookup(op Opcode) (*OpDefinition, error) { def, ok := definitions[op] if !ok { - return nil, fmt.Errorf("%w: %d", ErrUnknownOp, op) + return nil, fmt.Errorf("%w: %d", ErrUnknownOpcode, op) } return def, nil } diff --git a/pkg/bytecode/compiler.go b/pkg/bytecode/compiler.go index 59850056..55af0051 100644 --- a/pkg/bytecode/compiler.go +++ b/pkg/bytecode/compiler.go @@ -1,15 +1,19 @@ package bytecode import ( - "errors" "fmt" "evylang.dev/evy/pkg/parser" ) -// ErrUndefinedVar is returned when a variable name cannot -// be resolved in the symbol table. -var ErrUndefinedVar = errors.New("undefined variable") +var ( + // ErrUndefinedVar is returned when a variable name cannot + // be resolved in the symbol table. + ErrUndefinedVar = fmt.Errorf("%w: undefined variable", ErrPanic) + // ErrUnknownOperator is returned when an operator cannot + // be resolved. + ErrUnknownOperator = fmt.Errorf("%w: unknown operator", ErrInternal) +) // Compiler is responsible for turning a parsed evy program into // bytecode. @@ -43,10 +47,12 @@ func (c *Compiler) Compile(node parser.Node) error { return c.compileDecl(node.Decl) case *parser.AssignmentStmt: return c.compileAssignment(node) + case *parser.BinaryExpression: + return c.compileBinaryExpression(node) case *parser.Var: return c.compileVar(node) case *parser.NumLiteral: - num := &numVal{V: node.Value} + num := numVal(node.Value) if err := c.emit(OpConstant, c.addConstant(num)); err != nil { return err } @@ -86,6 +92,24 @@ func (c *Compiler) emit(op Opcode, operands ...int) error { return nil } +func (c *Compiler) compileBinaryExpression(expr *parser.BinaryExpression) error { + if err := c.Compile(expr.Left); err != nil { + return err + } + if err := c.Compile(expr.Right); err != nil { + return err + } + switch expr.Op { + case parser.OP_PLUS: + return c.emit(OpAdd) + case parser.OP_MINUS: + return c.emit(OpSubtract) + // more operators to follow (*, /, %). + default: + return fmt.Errorf("%w %s", ErrUnknownOperator, expr.Op) + } +} + func (c *Compiler) compileProgram(prog *parser.Program) error { for _, s := range prog.Statements { if err := c.Compile(s); err != nil { diff --git a/pkg/bytecode/compiler_test.go b/pkg/bytecode/compiler_test.go deleted file mode 100644 index 7f3b0a99..00000000 --- a/pkg/bytecode/compiler_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package bytecode - -import ( - "testing" - - "evylang.dev/evy/pkg/assert" - "evylang.dev/evy/pkg/parser" -) - -type compilerTestCase struct { - input string - expectedConstants []interface{} - expectedInstructions []Instructions -} - -func TestGlobalVarStatements(t *testing.T) { - tests := []compilerTestCase{ - { - input: "x := 1\nx = x", - expectedConstants: []interface{}{1}, - expectedInstructions: []Instructions{ - mustMake(t, OpConstant, 0), - mustMake(t, OpSetGlobal, 0), - mustMake(t, OpGetGlobal, 0), - mustMake(t, OpSetGlobal, 0), - }, - }, - { - input: "x := 2\nx = x", - expectedConstants: []interface{}{2}, - expectedInstructions: []Instructions{ - mustMake(t, OpConstant, 0), - mustMake(t, OpSetGlobal, 0), - mustMake(t, OpGetGlobal, 0), - mustMake(t, OpSetGlobal, 0), - }, - }, - } - for _, tt := range tests { - program, err := parser.Parse(tt.input, parser.Builtins{}) - assert.NoError(t, err, "parser error") - compiler := NewCompiler() - err = compiler.Compile(program) - assert.NoError(t, err, "compiler error") - bytecode := compiler.Bytecode() - assertInstructions(t, tt.expectedInstructions, bytecode.Instructions) - assertConstants(t, tt.expectedConstants, bytecode.Constants) - } -} - -func assertInstructions(t *testing.T, expected []Instructions, actual Instructions) { - t.Helper() - concatted := concatInstructions(expected) - assert.Equal(t, len(concatted), len(actual), "wrong instructions length") - for i, ins := range concatted { - assert.Equal(t, ins, actual[i], "wrong instruction %d", i) - } -} - -func concatInstructions(s []Instructions) Instructions { - out := Instructions{} - for _, ins := range s { - out = append(out, ins...) - } - return out -} - -func assertConstants(t *testing.T, expected []interface{}, actual []value) { - t.Helper() - assert.Equal(t, len(expected), len(actual), "wrong number of constants") - for i, constant := range expected { - switch constant := constant.(type) { - case int: - assertNumValue(t, float64(constant), actual[i]) - case float64: - assertNumValue(t, constant, actual[i]) - default: - t.Errorf("unknown constant type %v", constant) - } - } -} - -func assertNumValue(t *testing.T, expected float64, actual value) { - t.Helper() - result, ok := actual.(*numVal) - assert.Equal(t, true, ok, "object is not a NumVal. got=%T (%+v)", actual, actual) - assert.Equal(t, expected, result.V, "object has wrong value") -} diff --git a/pkg/bytecode/value.go b/pkg/bytecode/value.go index 60d3578d..44268de3 100644 --- a/pkg/bytecode/value.go +++ b/pkg/bytecode/value.go @@ -10,29 +10,20 @@ type value interface { Type() *parser.Type Equals(value) bool String() string - Set(value) } -type numVal struct { - V float64 -} - -func (n *numVal) Type() *parser.Type { return parser.NUM_TYPE } +type numVal float64 -func (n *numVal) String() string { return strconv.FormatFloat(n.V, 'f', -1, 64) } +func (n numVal) Type() *parser.Type { return parser.NUM_TYPE } -func (n *numVal) Equals(v value) bool { - n2, ok := v.(*numVal) - if !ok { - panic("internal error: Num.Equals called with non-Num value") - } - return n.V == n2.V +func (n numVal) String() string { + return strconv.FormatFloat(float64(n), 'f', -1, 64) } -func (n *numVal) Set(v value) { - n2, ok := v.(*numVal) +func (n numVal) Equals(v value) bool { + n2, ok := v.(numVal) if !ok { - panic("internal error: Num.Set called with with non-Num value") + panic("internal error: Num.Equals called with non-Num value") } - *n = *n2 + return n == n2 } diff --git a/pkg/bytecode/vm.go b/pkg/bytecode/vm.go index bb83a4e1..d7c4ffea 100644 --- a/pkg/bytecode/vm.go +++ b/pkg/bytecode/vm.go @@ -1,7 +1,7 @@ package bytecode import ( - "errors" + "fmt" ) const ( @@ -13,7 +13,7 @@ const ( ) // ErrStackOverflow is returned when the stack exceeds its size limit. -var ErrStackOverflow = errors.New("stack overflow") +var ErrStackOverflow = fmt.Errorf("%w: stack overflow", ErrPanic) // VM is responsible for executing evy programs from bytecode. type VM struct { @@ -48,21 +48,29 @@ func (vm *VM) Run() error { case OpConstant: constIndex := ReadUint16(vm.instructions[ip+1:]) ip += 2 - err := vm.push(vm.constants[constIndex]) - if err != nil { + if err := vm.push(vm.constants[constIndex]); err != nil { return err } case OpGetGlobal: globalIndex := ReadUint16(vm.instructions[ip+1:]) ip += 2 - err := vm.push(vm.globals[globalIndex]) - if err != nil { + if err := vm.push(vm.globals[globalIndex]); err != nil { return err } case OpSetGlobal: globalIndex := ReadUint16(vm.instructions[ip+1:]) ip += 2 vm.globals[globalIndex] = vm.pop() + case OpAdd: + right, left := vm.popBinaryNums() + if err := vm.push(numVal(left + right)); err != nil { + return err + } + case OpSubtract: + right, left := vm.popBinaryNums() + if err := vm.push(numVal(left - right)); err != nil { + return err + } } } return nil @@ -93,3 +101,25 @@ func (vm *VM) pop() value { vm.sp-- return o } + +// popBinaryNums pops the top two elements of the stack (the left +// and right sides of the binary expressions) as nums and returns both. +func (vm *VM) popBinaryNums() (float64, float64) { + // the right was compiled last, so is higher on the stack + // than the left + right := vm.popNumVal() + left := vm.popNumVal() + return float64(right), float64(left) +} + +// popNumVal pops an element from the stack and casts it to a num +// before returning the value. If elem is not a num it will error. +func (vm *VM) popNumVal() numVal { + elem := vm.pop() + val, ok := elem.(numVal) + if !ok { + panic(fmt.Errorf("%w: expected to pop numVal but got %s", + ErrInternal, elem.Type())) + } + return val +} diff --git a/pkg/bytecode/vm_test.go b/pkg/bytecode/vm_test.go index 52e34571..b59b48a5 100644 --- a/pkg/bytecode/vm_test.go +++ b/pkg/bytecode/vm_test.go @@ -7,33 +7,181 @@ import ( "evylang.dev/evy/pkg/parser" ) -func TestIntegerArithmetic(t *testing.T) { - tests := []vmTestCase{ - {"x := 1\nx = x", 1}, - {"y := 2\ny = y", 2}, +func TestVMGlobals(t *testing.T) { + tests := []testCase{ + { + name: "global assignment", + input: ` + x := 1 + x = x + `, + expectedStackTop: 1, + expectedConstants: []any{1}, + expectedInstructions: []Instructions{ + mustMake(t, OpConstant, 0), + mustMake(t, OpSetGlobal, 0), + mustMake(t, OpGetGlobal, 0), + mustMake(t, OpSetGlobal, 0), + }, + }, } for _, tt := range tests { - program, err := parser.Parse(tt.input, parser.Builtins{}) - assert.NoError(t, err, "parser error") - comp := NewCompiler() - err = comp.Compile(program) - assert.NoError(t, err, "compiler error") - vm := NewVM(comp.Bytecode()) - err = vm.Run() - assert.NoError(t, err, "vm error") - stackElem := vm.lastPoppedStackElem() - switch expected := tt.expected.(type) { + t.Run(tt.name, func(t *testing.T) { + program, err := parser.Parse(tt.input, parser.Builtins{}) + assert.NoError(t, err, "parser error") + comp := NewCompiler() + err = comp.Compile(program) + assert.NoError(t, err, "compiler error") + vm := NewVM(comp.Bytecode()) + err = vm.Run() + assert.NoError(t, err, "runtime error") + stackElem := vm.lastPoppedStackElem() + switch expected := tt.expectedStackTop.(type) { + case int: + assertNumValue(t, float64(expected), stackElem) + case float64: + assertNumValue(t, expected, stackElem) + default: + t.Errorf("unexpected object type %v", expected) + } + }) + } +} + +func TestVMArithmetic(t *testing.T) { + tests := []testCase{ + { + name: "addition", + input: ` + x := 2 + 1 + x = x + `, + expectedStackTop: 3, + expectedConstants: []any{2, 1}, + expectedInstructions: []Instructions{ + mustMake(t, OpConstant, 0), + mustMake(t, OpConstant, 1), + mustMake(t, OpAdd), + mustMake(t, OpSetGlobal, 0), + mustMake(t, OpGetGlobal, 0), + mustMake(t, OpSetGlobal, 0), + }, + }, + { + name: "subtraction", + input: ` + x := 2 - 1 + x = x + `, + expectedStackTop: 1, + expectedConstants: []any{2, 1}, + expectedInstructions: []Instructions{ + mustMake(t, OpConstant, 0), + mustMake(t, OpConstant, 1), + mustMake(t, OpSubtract), + mustMake(t, OpSetGlobal, 0), + mustMake(t, OpGetGlobal, 0), + mustMake(t, OpSetGlobal, 0), + }, + }, + { + name: "addition and subtraction", + input: ` + x := 2 - 1 + 3 + x = x + `, + expectedStackTop: 4, + expectedConstants: []any{2, 1, 3}, + expectedInstructions: []Instructions{ + mustMake(t, OpConstant, 0), + mustMake(t, OpConstant, 1), + mustMake(t, OpSubtract), + mustMake(t, OpConstant, 2), + mustMake(t, OpAdd), + mustMake(t, OpSetGlobal, 0), + mustMake(t, OpGetGlobal, 0), + mustMake(t, OpSetGlobal, 0), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + program, err := parser.Parse(tt.input, parser.Builtins{}) + assert.NoError(t, err, "parser error") + comp := NewCompiler() + err = comp.Compile(program) + assert.NoError(t, err, "compiler error") + bytecode := comp.Bytecode() + assertInstructions(t, tt.expectedInstructions, bytecode.Instructions) + assertConstants(t, tt.expectedConstants, bytecode.Constants) + vm := NewVM(bytecode) + err = vm.Run() + assert.NoError(t, err, "runtime error") + stackElem := vm.lastPoppedStackElem() + switch expected := tt.expectedStackTop.(type) { + case int: + assertNumValue(t, float64(expected), stackElem) + case float64: + assertNumValue(t, expected, stackElem) + default: + t.Errorf("unexpected object type %v", expected) + } + }) + } +} + +// testCase covers both the compiler and the VM. +type testCase struct { + name string + // input is an evy program + input string + // expectedStackTop is the result of popping the last + // element from the stack in the vm. + expectedStackTop any + // expectedConstants are the expected constants passed in the + // bytecode after compilation. + expectedConstants []any + // expectedInstructions are the expected compiler instructions + // passed in the bytecode after compilation + expectedInstructions []Instructions +} + +func assertInstructions(t *testing.T, expected []Instructions, actual Instructions) { + t.Helper() + concatted := concatInstructions(expected) + assert.Equal(t, len(concatted), len(actual), "wrong instructions length") + for i, ins := range concatted { + assert.Equal(t, ins, actual[i], "wrong instruction at %04d\nwant=\n%s\ngot=\n%s", + i, concatted, actual) + } +} + +func concatInstructions(s []Instructions) Instructions { + out := Instructions{} + for _, ins := range s { + out = append(out, ins...) + } + return out +} + +func assertConstants(t *testing.T, expected []any, actual []value) { + t.Helper() + assert.Equal(t, len(expected), len(actual), "wrong number of constants") + for i, constant := range expected { + switch constant := constant.(type) { case int: - assertNumValue(t, float64(expected), stackElem) + assertNumValue(t, float64(constant), actual[i]) case float64: - assertNumValue(t, expected, stackElem) + assertNumValue(t, constant, actual[i]) default: - t.Errorf("unexpected object type %v", expected) + t.Errorf("unknown constant type %v", constant) } } } -type vmTestCase struct { - input string - expected interface{} +func assertNumValue(t *testing.T, expected float64, actual value) { + t.Helper() + result, ok := actual.(numVal) + assert.Equal(t, true, ok, "object is not a NumVal. got=%T (%+v)", actual, actual) + assert.Equal(t, expected, float64(result), "object has wrong value") }