diff --git a/flow/pua/peerdb.go b/flow/pua/peerdb.go index 86defca9fa..bb633af864 100644 --- a/flow/pua/peerdb.go +++ b/flow/pua/peerdb.go @@ -4,9 +4,7 @@ import ( "bytes" "fmt" "math/big" - "reflect" "strconv" - "strings" "time" "github.com/google/uuid" @@ -22,7 +20,6 @@ import ( var ( LuaRecord = LuaUserDataType[model.Record]{Name: "peerdb_record"} LuaRow = LuaUserDataType[*model.RecordItems]{Name: "peerdb_row"} - LuaQValue = LuaUserDataType[qvalue.QValue]{Name: "peerdb_value"} LuaI64 = LuaUserDataType[int64]{Name: "flatbuffers_i64"} LuaU64 = LuaUserDataType[uint64]{Name: "flatbuffers_u64"} LuaTime = LuaUserDataType[time.Time]{Name: "peerdb_time"} @@ -44,11 +41,6 @@ func RegisterTypes(ls *lua.LState) { mt.RawSetString("__index", ls.NewFunction(LuaRowIndex)) mt.RawSetString("__len", ls.NewFunction(LuaRowLen)) - mt = LuaQValue.NewMetatable(ls) - mt.RawSetString("__index", ls.NewFunction(LuaQValueIndex)) - mt.RawSetString("__tostring", ls.NewFunction(LuaQValueString)) - mt.RawSetString("__len", ls.NewFunction(LuaQValueLen)) - mt = LuaUuid.NewMetatable(ls) mt.RawSetString("__index", ls.NewFunction(LuaUuidIndex)) mt.RawSetString("__tostring", ls.NewFunction(LuaUuidString)) @@ -83,9 +75,11 @@ func RegisterTypes(ls *lua.LState) { peerdb := ls.NewTable() peerdb.RawSetString("RowToJSON", ls.NewFunction(LuaRowToJSON)) peerdb.RawSetString("RowColumns", ls.NewFunction(LuaRowColumns)) + peerdb.RawSetString("RowColumnKind", ls.NewFunction(LuaRowColumnKind)) peerdb.RawSetString("Now", ls.NewFunction(LuaNow)) peerdb.RawSetString("UUID", ls.NewFunction(LuaUUID)) peerdb.RawSetString("type", ls.NewFunction(LuaType)) + peerdb.RawSetString("tostring", ls.NewFunction(LuaToString)) ls.Env.RawSetString("peerdb", peerdb) } @@ -117,16 +111,20 @@ func LoadPeerdbScript(ls *lua.LState) int { return 1 } -func LuaRowIndex(ls *lua.LState) int { +func GetRowQ(ls *lua.LState, row *model.RecordItems, col string) qvalue.QValue { row, key := LuaRow.StartIndex(ls) qv, err := row.GetValueByColName(key) if err != nil { ls.RaiseError(err.Error()) - return 0 + return qvalue.QValue{} } + return qv +} - ls.Push(LuaQValue.New(ls, qv)) +func LuaRowIndex(ls *lua.LState) int { + row, key := LuaRow.StartIndex(ls) + ls.Push(LuaQValue(ls, GetRowQ(ls, row, key))) return 1 } @@ -157,6 +155,12 @@ func LuaRowColumns(ls *lua.LState) int { return 1 } +func LuaRowColumnKind(ls *lua.LState) int { + row, key := LuaRow.StartIndex(ls) + ls.Push(lua.LString(GetRowQ(ls, row, key).Kind)) + return 1 +} + func LuaRecordIndex(ls *lua.LState) int { record, key := LuaRecord.StartIndex(ls) switch key { @@ -216,121 +220,87 @@ func LuaRecordIndex(ls *lua.LState) int { return 1 } -func qvToLTable[T any](ls *lua.LState, s []T, f func(x T) lua.LValue) { +func qvToLTable[T any](ls *lua.LState, s []T, f func(x T) lua.LValue) *lua.LTable { tbl := ls.CreateTable(len(s), 0) for idx, val := range s { tbl.RawSetInt(idx, f(val)) } - ls.Push(tbl) + return tbl } -func LuaQValueIndex(ls *lua.LState) int { - qv, key := LuaQValue.StartIndex(ls) - switch key { - case "kind": - ls.Push(lua.LString(qv.Kind)) - case "value": - switch v := qv.Value.(type) { - case nil: - ls.Push(lua.LNil) - case bool: - ls.Push(lua.LBool(v)) - case uint8: - if qv.Kind == qvalue.QValueKindQChar { - ls.Push(lua.LString(rune(v))) - } else { - ls.Push(lua.LNumber(v)) - } - case int16: - ls.Push(lua.LNumber(v)) - case int32: - ls.Push(lua.LNumber(v)) - case int64: - ls.Push(LuaI64.New(ls, v)) - case float32: - ls.Push(lua.LNumber(v)) - case float64: - ls.Push(lua.LNumber(v)) - case string: - if qv.Kind == qvalue.QValueKindUUID { - u, err := uuid.Parse(v) - if err != nil { - ls.Push(LuaUuid.New(ls, u)) - } else { - ls.Push(lua.LString(v)) - } - } else { - ls.Push(lua.LString(v)) +func LuaQValue(ls *lua.LState, qv qvalue.QValue) lua.LValue { + switch v := qv.Value.(type) { + case nil: + return lua.LNil + case bool: + return lua.LBool(v) + case uint8: + if qv.Kind == qvalue.QValueKindQChar { + return lua.LString(rune(v)) + } else { + return lua.LNumber(v) + } + case int16: + return lua.LNumber(v) + case int32: + return lua.LNumber(v) + case int64: + return LuaI64.New(ls, v) + case float32: + return lua.LNumber(v) + case float64: + return lua.LNumber(v) + case string: + if qv.Kind == qvalue.QValueKindUUID { + u, err := uuid.Parse(v) + if err != nil { + return LuaUuid.New(ls, u) } - case time.Time: - ls.Push(LuaTime.New(ls, v)) - case decimal.Decimal: - ls.Push(LuaDecimal.New(ls, v)) - case [16]byte: - ls.Push(LuaUuid.New(ls, uuid.UUID(v))) - case []byte: - ls.Push(lua.LString(v)) - case []float32: - qvToLTable(ls, v, func(f float32) lua.LValue { - return lua.LNumber(f) - }) - case []float64: - qvToLTable(ls, v, func(f float64) lua.LValue { - return lua.LNumber(f) - }) - case []int16: - qvToLTable(ls, v, func(f int16) lua.LValue { - return lua.LNumber(f) - }) - case []int32: - qvToLTable(ls, v, func(f int32) lua.LValue { - return lua.LNumber(f) - }) - case []int64: - qvToLTable(ls, v, func(x int64) lua.LValue { - return LuaI64.New(ls, x) - }) - case []string: - qvToLTable(ls, v, func(x string) lua.LValue { - return lua.LString(x) - }) - case []time.Time: - qvToLTable(ls, v, func(x time.Time) lua.LValue { - return LuaTime.New(ls, x) - }) - case []bool: - qvToLTable(ls, v, func(x bool) lua.LValue { - return lua.LBool(x) - }) } - case "int64": - ls.Push(LuaI64.New(ls, reflect.ValueOf(qv.Value).Int())) - case "float64": - ls.Push(lua.LNumber(reflect.ValueOf(qv.Value).Float())) + return lua.LString(v) + case time.Time: + return LuaTime.New(ls, v) + case decimal.Decimal: + return LuaDecimal.New(ls, v) + case [16]byte: + return LuaUuid.New(ls, uuid.UUID(v)) + case []byte: + return lua.LString(v) + case []float32: + return qvToLTable(ls, v, func(f float32) lua.LValue { + return lua.LNumber(f) + }) + case []float64: + return qvToLTable(ls, v, func(f float64) lua.LValue { + return lua.LNumber(f) + }) + case []int16: + return qvToLTable(ls, v, func(x int16) lua.LValue { + return lua.LNumber(x) + }) + case []int32: + return qvToLTable(ls, v, func(x int32) lua.LValue { + return lua.LNumber(x) + }) + case []int64: + return qvToLTable(ls, v, func(x int64) lua.LValue { + return LuaI64.New(ls, x) + }) + case []string: + return qvToLTable(ls, v, func(x string) lua.LValue { + return lua.LString(x) + }) + case []time.Time: + return qvToLTable(ls, v, func(x time.Time) lua.LValue { + return LuaTime.New(ls, x) + }) + case []bool: + return qvToLTable(ls, v, func(x bool) lua.LValue { + return lua.LBool(x) + }) default: - return 0 + return lua.LString(fmt.Sprint(qv.Value)) } - return 1 -} - -func LuaQValueLen(ls *lua.LState) int { - qv := LuaQValue.StartMeta(ls) - str, ok := qv.Value.(string) - if ok { - ls.Push(lua.LNumber(len(str))) - return 1 - } - if strings.HasPrefix(string(qv.Kind), "array_") { - ls.Push(lua.LNumber(reflect.ValueOf(qv.Value).Len())) - return 1 - } - return 0 -} - -func LuaQValueString(ls *lua.LState) int { - qv := LuaQValue.StartMeta(ls) - ls.Push(lua.LString(fmt.Sprint(qv.Value))) - return 1 } func LuaUuidIndex(ls *lua.LState) int { @@ -369,6 +339,15 @@ func LuaType(ls *lua.LState) int { return 0 } +func LuaToString(ls *lua.LState) int { + val := ls.Get(1) + if ud, ok := val.(*lua.LUserData); ok { + ls.Push(lua.LString(fmt.Sprint(ud.Value))) + return 1 + } + return 0 +} + func Lua64Eq(ls *lua.LState) int { aud := ls.CheckUserData(1) bud := ls.CheckUserData(2)