diff --git a/flow/pua/peerdb.go b/flow/pua/peerdb.go index 540065970f..7330cb9d09 100644 --- a/flow/pua/peerdb.go +++ b/flow/pua/peerdb.go @@ -3,6 +3,7 @@ package pua import ( "bytes" "fmt" + "math/big" "time" "github.com/google/uuid" @@ -72,6 +73,7 @@ func RegisterTypes(ls *lua.LState) { mt.RawSetString("__eq", ls.NewFunction(LuaBigIntEq)) mt.RawSetString("__le", ls.NewFunction(LuaBigIntLe)) mt.RawSetString("__lt", ls.NewFunction(LuaBigIntLt)) + mt.RawSetString("__unm", ls.NewFunction(LuaBigIntUnm)) mt = shared.LuaDecimal.NewMetatable(ls) mt.RawSetString("__index", ls.NewFunction(LuaDecimalIndex)) @@ -79,6 +81,25 @@ func RegisterTypes(ls *lua.LState) { mt.RawSetString("__eq", ls.NewFunction(LuaDecimalEq)) mt.RawSetString("__le", ls.NewFunction(LuaDecimalLe)) mt.RawSetString("__lt", ls.NewFunction(LuaDecimalLt)) + mt.RawSetString("__unm", ls.NewFunction(LuaDecimalUnm)) + mt.RawSetString("__add", ls.NewFunction(decimalBinop(func(d1 decimal.Decimal, d2 decimal.Decimal) decimal.Decimal { + return d1.Add(d2) + }))) + mt.RawSetString("__sub", ls.NewFunction(decimalBinop(func(d1 decimal.Decimal, d2 decimal.Decimal) decimal.Decimal { + return d1.Sub(d2) + }))) + mt.RawSetString("__mul", ls.NewFunction(decimalBinop(func(d1 decimal.Decimal, d2 decimal.Decimal) decimal.Decimal { + return d1.Mul(d2) + }))) + mt.RawSetString("__div", ls.NewFunction(decimalBinop(func(d1 decimal.Decimal, d2 decimal.Decimal) decimal.Decimal { + return d1.Div(d2) + }))) + mt.RawSetString("__mod", ls.NewFunction(decimalBinop(func(d1 decimal.Decimal, d2 decimal.Decimal) decimal.Decimal { + return d1.Mod(d2) + }))) + mt.RawSetString("__pow", ls.NewFunction(decimalBinop(func(d1 decimal.Decimal, d2 decimal.Decimal) decimal.Decimal { + return d1.Pow(d2) + }))) mt.RawSetString("__msgpack", ls.NewFunction(LuaDecimalString)) peerdb := ls.NewTable() @@ -330,19 +351,37 @@ func LuaUUID(ls *lua.LState) int { return 1 } -func LuaParseDecimal(ls *lua.LState) int { - switch v := ls.Get(1).(type) { +func LVAsDecimal(ls *lua.LState, lv lua.LValue) decimal.Decimal { + switch v := lv.(type) { case lua.LNumber: - ls.Push(shared.LuaDecimal.New(ls, decimal.NewFromFloat(float64(v)))) + return decimal.NewFromFloat(float64(v)) case lua.LString: d, err := decimal.NewFromString(string(v)) if err != nil { ls.RaiseError(err.Error()) } - ls.Push(shared.LuaDecimal.New(ls, d)) + return d + case *lua.LUserData: + switch v := v.Value.(type) { + case int64: + return decimal.NewFromInt(v) + case uint64: + return decimal.NewFromUint64(v) + case *big.Int: + return decimal.NewFromBigInt(v, 0) + case decimal.Decimal: + return v + default: + ls.RaiseError("cannot create decimal from %T", v) + } default: - ls.RaiseError("cannot create decimal from " + v.Type().String()) + ls.RaiseError("cannot create decimal from %s", v.Type()) } + return decimal.Decimal{} +} + +func LuaParseDecimal(ls *lua.LState) int { + ls.Push(shared.LuaDecimal.New(ls, LVAsDecimal(ls, ls.Get(1)))) return 1 } @@ -449,6 +488,12 @@ func LuaBigIntString(ls *lua.LState) int { return 1 } +func LuaBigIntUnm(ls *lua.LState) int { + bi := shared.LuaBigInt.StartMethod(ls) + ls.Push(shared.LuaBigInt.New(ls, new(big.Int).Neg(bi))) + return 1 +} + func LuaBigIntEq(ls *lua.LState) int { t1 := shared.LuaBigInt.StartMethod(ls) _, t2 := shared.LuaBigInt.Check(ls, 2) @@ -497,6 +542,19 @@ func LuaDecimalString(ls *lua.LState) int { return 1 } +func LuaDecimalUnm(ls *lua.LState) int { + num := shared.LuaDecimal.StartMethod(ls) + ls.Push(shared.LuaDecimal.New(ls, num.Neg())) + return 1 +} + +func decimalBinop(f func(d1 decimal.Decimal, d2 decimal.Decimal) decimal.Decimal) func(ls *lua.LState) int { + return func(ls *lua.LState) int { + ls.Push(shared.LuaDecimal.New(ls, f(LVAsDecimal(ls, ls.Get(1)), LVAsDecimal(ls, ls.Get(2))))) + return 1 + } +} + func LuaDecimalEq(ls *lua.LState) int { t1 := shared.LuaDecimal.StartMethod(ls) _, t2 := shared.LuaDecimal.Check(ls, 2)