diff --git a/flow/pua/flatbuffers_binaryarray.go b/flow/pua/flatbuffers_binaryarray.go index e296c871bf..50cd5ad46c 100644 --- a/flow/pua/flatbuffers_binaryarray.go +++ b/flow/pua/flatbuffers_binaryarray.go @@ -27,7 +27,7 @@ func FlatBuffers_BinaryArray_Loader(ls *lua.LState) int { } func BinaryArrayNew(ls *lua.LState) int { - lval := ls.Get(-1) + lval := ls.Get(1) var ba BinaryArray switch val := lval.(type) { case lua.LString: diff --git a/flow/pua/flatbuffers_builder.go b/flow/pua/flatbuffers_builder.go index b88b7f8f23..024c1d0a1c 100644 --- a/flow/pua/flatbuffers_builder.go +++ b/flow/pua/flatbuffers_builder.go @@ -17,10 +17,12 @@ type Builder struct { minalign uint8 } -func (b *Builder) EndVector(vectorSize int) int { +func (b *Builder) EndVector(ls *lua.LState, vectorSize int) int { if !b.nested { - panic("EndVector called outside nested context") + ls.RaiseError("EndVector called outside nested context") + return 0 } + b.nested = false b.PlaceU64(uint64(vectorSize), uint32n) return b.Offset() } @@ -78,10 +80,26 @@ func (b *Builder) PrependU64(n N, x uint64) { } func (b *Builder) PrependSlot(ls *lua.LState, n N, slotnum int, x lua.LValue, d lua.LValue) { - // TODO implement __eq for U64/I64 if !ls.Equal(x, d) { + if xud, ok := x.(*lua.LUserData); ok { + // Need to check int64/number because flatbuffers passes default as 0 + // but Lua only calls __eq when both operands are same type + if dn, ok := d.(lua.LNumber); ok { + switch xv := xud.Value.(type) { + case int64: + if xv == int64(dn) { + return + } + case uint64: + if xv == uint64(dn) { + return + } + } + } + } + b.Prepend(ls, n, x) - b.Slot(slotnum) + b.Slot(ls, slotnum) } } @@ -108,9 +126,10 @@ func (b *Builder) PrependVOffsetT(off uint16) { b.PlaceU64(uint64(off), uint16n) } -func (b *Builder) Slot(slotnum int) { +func (b *Builder) Slot(ls *lua.LState, slotnum int) { if !b.nested { - panic("Slot called outside nested context") + ls.RaiseError("Slot called outside nested context") + return } for slotnum < len(b.currentVT) { b.currentVT = append(b.currentVT, 0) @@ -300,11 +319,11 @@ func BuilderStartObject(ls *lua.LState) int { ls.RaiseError("StartObject called inside nested context") return 0 } + b.nested = true numFields := int(ls.CheckNumber(2)) b.currentVT = make([]int, numFields) b.objectEnd = b.Offset() - b.nested = true return 0 } @@ -317,8 +336,10 @@ func BuilderWriteVtable(ls *lua.LState) int { func BuilderEndObject(ls *lua.LState) int { b := LuaBuilder.StartMeta(ls) if !b.nested { - panic("EndObject called outside nested context") + ls.RaiseError("EndObject called outside nested context") + return 0 } + b.nested = false ls.Push(lua.LNumber(b.WriteVtable(ls))) return 1 } @@ -380,12 +401,8 @@ func BuilderStartVector(ls *lua.LState) int { func BuilderEndVector(ls *lua.LState) int { b := LuaBuilder.StartMeta(ls) - if !b.nested { - ls.RaiseError("EndVector called outside nested context") - } - b.nested = false - b.PlaceU64(uint64(ls.CheckNumber(2)), uint32n) - ls.Push(lua.LNumber(b.Offset())) + size := int(ls.CheckNumber(2)) + ls.Push(lua.LNumber(b.EndVector(ls, size))) return 1 } @@ -404,13 +421,13 @@ func BuilderCreateString(ls *lua.LState) int { b.head -= lens copy(b.ba.data[b.head:], s) - return b.EndVector(lens) + return b.EndVector(ls, lens) } func BuilderSlot(ls *lua.LState) int { b := LuaBuilder.StartMeta(ls) slotnum := int(ls.CheckNumber(2)) - b.Slot(slotnum) + b.Slot(ls, slotnum) return 0 } @@ -516,7 +533,7 @@ func BuilderPrependStructSlot(ls *lua.LState) int { if x != b.Offset() { ls.RaiseError("Tried to write a Struct at an Offset that is different from the current Offset of the Builder.") } else { - b.Slot(int(ls.CheckNumber(2))) + b.Slot(ls, int(ls.CheckNumber(2))) } } return 0 @@ -528,7 +545,7 @@ func BuilderPrependUOffsetTRelativeSlot(ls *lua.LState) int { d := int(ls.CheckNumber(4)) if x != d { b.PrependOffsetTRelative(ls, x, uint32n) - b.Slot(int(ls.CheckNumber(2))) + b.Slot(ls, int(ls.CheckNumber(2))) } return 0 } diff --git a/flow/pua/flatbuffers_numtypes.go b/flow/pua/flatbuffers_numtypes.go index b174bee873..32a23fc205 100644 --- a/flow/pua/flatbuffers_numtypes.go +++ b/flow/pua/flatbuffers_numtypes.go @@ -59,10 +59,10 @@ func (n *N) Pack(ls *lua.LState, buf []byte, val lua.LValue) { switch lv := val.(type) { case *lua.LUserData: switch v := lv.Value.(type) { - case NI64: - n.PackU64(buf, uint64(v.val)) - case NU64: - n.PackU64(buf, v.val) + case int64: + n.PackU64(buf, uint64(v)) + case uint64: + n.PackU64(buf, v) default: n.PackU64(buf, 0) } @@ -88,10 +88,10 @@ func (n *N) Pack(ls *lua.LState, buf []byte, val lua.LValue) { switch lv := val.(type) { case *lua.LUserData: switch v := lv.Value.(type) { - case NI64: - n.PackU64(buf, math.Float64bits(float64(v.val))) - case NU64: - n.PackU64(buf, math.Float64bits(float64(v.val))) + case int64: + n.PackU64(buf, math.Float64bits(float64(v))) + case uint64: + n.PackU64(buf, math.Float64bits(float64(v))) default: n.PackU64(buf, 0) } @@ -146,9 +146,9 @@ func (n *N) Unpack(ls *lua.LState, buf []byte) lua.LValue { case 8: u64 := binary.LittleEndian.Uint64(buf) if n.signed { - return LuaNI64.New(ls, NI64{int64(u64)}) + return LuaI64.New(ls, int64(u64)) } else { - return LuaNU64.New(ls, NU64{u64}) + return LuaU64.New(ls, u64) } } case tyfloat: @@ -165,16 +165,7 @@ func (n *N) Unpack(ls *lua.LState, buf []byte) lua.LValue { panic("invalid numeric metatype") } -type ( - NI64 struct{ val int64 } - NU64 struct{ val uint64 } -) - -var ( - LuaN = LuaUserDataType[N]{Name: "flatbuffers_n"} - LuaNI64 = LuaUserDataType[NI64]{Name: "flatbuffers_i64"} - LuaNU64 = LuaUserDataType[NU64]{Name: "flatbuffers_u64"} -) +var LuaN = LuaUserDataType[N]{Name: "flatbuffers_n"} func FlatBuffers_N_Loader(ls *lua.LState) int { m := ls.NewTable() diff --git a/flow/pua/peerdb.go b/flow/pua/peerdb.go index ebbd1e8e5c..9f0c852516 100644 --- a/flow/pua/peerdb.go +++ b/flow/pua/peerdb.go @@ -21,6 +21,8 @@ var ( LuaRecord = LuaUserDataType[model.Record]{Name: "peerdb_record"} LuaRow = LuaUserDataType[*model.RecordItems]{Name: "peerdb_row"} LuaQValue = LuaUserDataType[qvalue.QValue]{Name: "peerdb_value"} + LuaI64 = LuaUserDataType[int64]{Name: "flatbuffers_i64"} + LuaU64 = LuaUserDataType[uint64]{Name: "flatbuffers_u64"} LuaTime = LuaUserDataType[time.Time]{Name: "peerdb_time"} LuaUuid = LuaUserDataType[uuid.UUID]{Name: "peerdb_uuid"} LuaBigInt = LuaUserDataType[*big.Int]{Name: "peerdb_bigint"} @@ -49,6 +51,18 @@ func RegisterTypes(ls *lua.LState) { mt.RawSetString("__index", ls.NewFunction(LuaUuidIndex)) mt.RawSetString("__string", ls.NewFunction(LuaUuidString)) + mt = LuaI64.NewMetatable(ls) + mt.RawSetString("__index", ls.NewFunction(LuaI64Index)) + mt.RawSetString("__eq", ls.NewFunction(Lua64Eq)) + mt.RawSetString("__le", ls.NewFunction(Lua64Le)) + mt.RawSetString("__lt", ls.NewFunction(Lua64Lt)) + + mt = LuaU64.NewMetatable(ls) + mt.RawSetString("__index", ls.NewFunction(LuaU64Index)) + mt.RawSetString("__eq", ls.NewFunction(Lua64Eq)) + mt.RawSetString("__le", ls.NewFunction(Lua64Le)) + mt.RawSetString("__lt", ls.NewFunction(Lua64Lt)) + mt = LuaTime.NewMetatable(ls) mt.RawSetString("__index", ls.NewFunction(LuaTimeIndex)) @@ -183,7 +197,7 @@ func LuaRecordIndex(ls *lua.LState) int { ls.Push(lua.LNil) } case "checkpoint": - ls.Push(LuaNI64.New(ls, NI64{record.GetCheckpointID()})) + ls.Push(LuaI64.New(ls, record.GetCheckpointID())) case "target": ls.Push(lua.LString(record.GetDestinationTableName())) case "source": @@ -224,7 +238,7 @@ func LuaQValueIndex(ls *lua.LState) int { case int32: ls.Push(lua.LNumber(v)) case int64: - ls.Push(LuaNI64.New(ls, NI64{v})) + ls.Push(LuaI64.New(ls, v)) case float32: ls.Push(lua.LNumber(v)) case float64: @@ -266,7 +280,7 @@ func LuaQValueIndex(ls *lua.LState) int { }) case []int64: qvToLTable(ls, v, func(x int64) lua.LValue { - return LuaNI64.New(ls, NI64{x}) + return LuaI64.New(ls, x) }) case []string: qvToLTable(ls, v, func(x string) lua.LValue { @@ -282,7 +296,7 @@ func LuaQValueIndex(ls *lua.LState) int { }) } case "int64": - ls.Push(LuaNI64.New(ls, NI64{reflect.ValueOf(qv.Value).Int()})) + ls.Push(LuaI64.New(ls, reflect.ValueOf(qv.Value).Int())) case "float64": ls.Push(lua.LNumber(reflect.ValueOf(qv.Value).Float())) default: @@ -338,17 +352,157 @@ func LuaUUID(ls *lua.LState) int { return 1 } +func Lua64Eq(ls *lua.LState) int { + aud := ls.CheckUserData(1) + bud := ls.CheckUserData(2) + switch a := aud.Value.(type) { + case int64: + switch b := bud.Value.(type) { + case int64: + ls.Push(lua.LBool(a == b)) + case uint64: + if a < 0 { + ls.Push(lua.LFalse) + } else { + ls.Push(lua.LBool(uint64(a) == b)) + } + default: + return 0 + } + case uint64: + switch b := bud.Value.(type) { + case int64: + if b < 0 { + ls.Push(lua.LFalse) + } else { + ls.Push(lua.LBool(a == uint64(b))) + } + case uint64: + ls.Push(lua.LBool(a == b)) + default: + return 0 + } + default: + return 0 + } + return 1 +} + +func Lua64Le(ls *lua.LState) int { + aud := ls.CheckUserData(1) + bud := ls.CheckUserData(2) + switch a := aud.Value.(type) { + case int64: + switch b := bud.Value.(type) { + case int64: + ls.Push(lua.LBool(a <= b)) + case uint64: + if a < 0 { + ls.Push(lua.LTrue) + } else { + ls.Push(lua.LBool(uint64(a) <= b)) + } + default: + return 0 + } + case uint64: + switch b := bud.Value.(type) { + case int64: + if b < 0 { + ls.Push(lua.LFalse) + } else { + ls.Push(lua.LBool(a <= uint64(b))) + } + case uint64: + ls.Push(lua.LBool(a <= b)) + default: + return 0 + } + default: + return 0 + } + return 1 +} + +func Lua64Lt(ls *lua.LState) int { + aud := ls.CheckUserData(1) + bud := ls.CheckUserData(2) + switch a := aud.Value.(type) { + case int64: + switch b := bud.Value.(type) { + case int64: + ls.Push(lua.LBool(a < b)) + case uint64: + if a < 0 { + ls.Push(lua.LTrue) + } else { + ls.Push(lua.LBool(uint64(a) < b)) + } + default: + return 0 + } + case uint64: + switch b := bud.Value.(type) { + case int64: + if b < 0 { + ls.Push(lua.LTrue) + } else { + ls.Push(lua.LBool(a < uint64(b))) + } + case uint64: + ls.Push(lua.LBool(a < b)) + default: + return 0 + } + default: + return 0 + } + return 1 +} + +func LuaI64Index(ls *lua.LState) int { + i64ud, i64 := LuaI64.Check(ls, 1) + key := ls.CheckString(2) + switch key { + case "i64": + ls.Push(i64ud) + case "u64": + ls.Push(LuaU64.New(ls, uint64(i64))) + case "float64": + ls.Push(lua.LNumber(i64)) + default: + return 0 + } + return 1 +} + +func LuaU64Index(ls *lua.LState) int { + u64ud, u64 := LuaU64.Check(ls, 1) + key := ls.CheckString(2) + switch key { + case "i64": + ls.Push(LuaI64.New(ls, int64(u64))) + case "u64": + ls.Push(u64ud) + case "float64": + ls.Push(lua.LNumber(u64)) + default: + return 0 + } + return 1 +} + func LuaTimeIndex(ls *lua.LState) int { tm, key := LuaTime.StartIndex(ls) switch key { case "unix_nano": - ls.Push(LuaNI64.New(ls, NI64{tm.UnixNano()})) + ls.Push(LuaI64.New(ls, tm.UnixNano())) case "unix_micro": - ls.Push(LuaNI64.New(ls, NI64{tm.UnixMicro()})) + ls.Push(LuaI64.New(ls, tm.UnixMicro())) case "unix_milli": - ls.Push(LuaNI64.New(ls, NI64{tm.UnixMilli()})) + ls.Push(LuaI64.New(ls, tm.UnixMilli())) case "unix": - ls.Push(LuaNI64.New(ls, NI64{tm.Unix()})) + ls.Push(LuaI64.New(ls, tm.Unix())) case "year": ls.Push(lua.LNumber(tm.Year())) case "month":