Skip to content

Commit

Permalink
Lua decimal math (#1671)
Browse files Browse the repository at this point in the history
  • Loading branch information
serprex authored May 6, 2024
1 parent bdbdfca commit 524d77e
Showing 1 changed file with 63 additions and 5 deletions.
68 changes: 63 additions & 5 deletions flow/pua/peerdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pua
import (
"bytes"
"fmt"
"math/big"
"time"

"github.com/google/uuid"
Expand Down Expand Up @@ -72,13 +73,33 @@ func RegisterTypes(ls *lua.LState) {
mt.RawSetString("__eq", ls.NewFunction(LuaBigIntEq))
mt.RawSetString("__le", ls.NewFunction(LuaBigIntLe))
mt.RawSetString("__lt", ls.NewFunction(LuaBigIntLt))
mt.RawSetString("__unm", ls.NewFunction(LuaBigIntUnm))

mt = shared.LuaDecimal.NewMetatable(ls)
mt.RawSetString("__index", ls.NewFunction(LuaDecimalIndex))
mt.RawSetString("__tostring", ls.NewFunction(LuaDecimalString))
mt.RawSetString("__eq", ls.NewFunction(LuaDecimalEq))
mt.RawSetString("__le", ls.NewFunction(LuaDecimalLe))
mt.RawSetString("__lt", ls.NewFunction(LuaDecimalLt))
mt.RawSetString("__unm", ls.NewFunction(LuaDecimalUnm))
mt.RawSetString("__add", ls.NewFunction(decimalBinop(func(d1 decimal.Decimal, d2 decimal.Decimal) decimal.Decimal {
return d1.Add(d2)
})))
mt.RawSetString("__sub", ls.NewFunction(decimalBinop(func(d1 decimal.Decimal, d2 decimal.Decimal) decimal.Decimal {
return d1.Sub(d2)
})))
mt.RawSetString("__mul", ls.NewFunction(decimalBinop(func(d1 decimal.Decimal, d2 decimal.Decimal) decimal.Decimal {
return d1.Mul(d2)
})))
mt.RawSetString("__div", ls.NewFunction(decimalBinop(func(d1 decimal.Decimal, d2 decimal.Decimal) decimal.Decimal {
return d1.Div(d2)
})))
mt.RawSetString("__mod", ls.NewFunction(decimalBinop(func(d1 decimal.Decimal, d2 decimal.Decimal) decimal.Decimal {
return d1.Mod(d2)
})))
mt.RawSetString("__pow", ls.NewFunction(decimalBinop(func(d1 decimal.Decimal, d2 decimal.Decimal) decimal.Decimal {
return d1.Pow(d2)
})))
mt.RawSetString("__msgpack", ls.NewFunction(LuaDecimalString))

peerdb := ls.NewTable()
Expand Down Expand Up @@ -330,19 +351,37 @@ func LuaUUID(ls *lua.LState) int {
return 1
}

func LuaParseDecimal(ls *lua.LState) int {
switch v := ls.Get(1).(type) {
func LVAsDecimal(ls *lua.LState, lv lua.LValue) decimal.Decimal {
switch v := lv.(type) {
case lua.LNumber:
ls.Push(shared.LuaDecimal.New(ls, decimal.NewFromFloat(float64(v))))
return decimal.NewFromFloat(float64(v))
case lua.LString:
d, err := decimal.NewFromString(string(v))
if err != nil {
ls.RaiseError(err.Error())
}
ls.Push(shared.LuaDecimal.New(ls, d))
return d
case *lua.LUserData:
switch v := v.Value.(type) {
case int64:
return decimal.NewFromInt(v)
case uint64:
return decimal.NewFromUint64(v)
case *big.Int:
return decimal.NewFromBigInt(v, 0)
case decimal.Decimal:
return v
default:
ls.RaiseError("cannot create decimal from %T", v)
}
default:
ls.RaiseError("cannot create decimal from " + v.Type().String())
ls.RaiseError("cannot create decimal from %s", v.Type())
}
return decimal.Decimal{}
}

func LuaParseDecimal(ls *lua.LState) int {
ls.Push(shared.LuaDecimal.New(ls, LVAsDecimal(ls, ls.Get(1))))
return 1
}

Expand Down Expand Up @@ -449,6 +488,12 @@ func LuaBigIntString(ls *lua.LState) int {
return 1
}

func LuaBigIntUnm(ls *lua.LState) int {
bi := shared.LuaBigInt.StartMethod(ls)
ls.Push(shared.LuaBigInt.New(ls, new(big.Int).Neg(bi)))
return 1
}

func LuaBigIntEq(ls *lua.LState) int {
t1 := shared.LuaBigInt.StartMethod(ls)
_, t2 := shared.LuaBigInt.Check(ls, 2)
Expand Down Expand Up @@ -497,6 +542,19 @@ func LuaDecimalString(ls *lua.LState) int {
return 1
}

func LuaDecimalUnm(ls *lua.LState) int {
num := shared.LuaDecimal.StartMethod(ls)
ls.Push(shared.LuaDecimal.New(ls, num.Neg()))
return 1
}

func decimalBinop(f func(d1 decimal.Decimal, d2 decimal.Decimal) decimal.Decimal) func(ls *lua.LState) int {
return func(ls *lua.LState) int {
ls.Push(shared.LuaDecimal.New(ls, f(LVAsDecimal(ls, ls.Get(1)), LVAsDecimal(ls, ls.Get(2)))))
return 1
}
}

func LuaDecimalEq(ls *lua.LState) int {
t1 := shared.LuaDecimal.StartMethod(ls)
_, t2 := shared.LuaDecimal.Check(ls, 2)
Expand Down

0 comments on commit 524d77e

Please sign in to comment.