diff --git a/flow/pua/peerdb.go b/flow/pua/peerdb.go index 540065970f..8e3221aba1 100644 --- a/flow/pua/peerdb.go +++ b/flow/pua/peerdb.go @@ -51,6 +51,7 @@ func RegisterTypes(ls *lua.LState) { mt = LuaRow.NewMetatable(ls) mt.RawSetString("__index", ls.NewFunction(LuaRowIndex)) + mt.RawSetString("__newindex", ls.NewFunction(LuaRowNewIndex)) mt.RawSetString("__len", ls.NewFunction(LuaRowLen)) mt = shared.LuaUuid.NewMetatable(ls) @@ -136,6 +137,160 @@ func LuaRowIndex(ls *lua.LState) int { return 1 } +func LuaRowNewIndex(ls *lua.LState) int { + _, row := LuaRow.Check(ls, 1) + key := ls.CheckString(2) + val := ls.Get(3) + qv := row.GetColumnValue(key) + kind := qv.Kind() + if val == lua.LNil { + row.AddColumn(key, qvalue.QValueNull(kind)) + } + var newqv qvalue.QValue + switch qv.Kind() { + case qvalue.QValueKindInvalid: + newqv = qvalue.QValueInvalid{Val: lua.LVAsString(val)} + case qvalue.QValueKindFloat32: + newqv = qvalue.QValueFloat32{Val: float32(lua.LVAsNumber(val))} + case qvalue.QValueKindFloat64: + newqv = qvalue.QValueFloat64{Val: float64(lua.LVAsNumber(val))} + case qvalue.QValueKindInt16: + newqv = qvalue.QValueInt16{Val: int16(lua.LVAsNumber(val))} + case qvalue.QValueKindInt32: + newqv = qvalue.QValueInt32{Val: int32(lua.LVAsNumber(val))} + case qvalue.QValueKindInt64: + switch v := val.(type) { + case lua.LNumber: + newqv = qvalue.QValueInt64{Val: int64(v)} + case *lua.LUserData: + switch i64 := v.Value.(type) { + case int64: + newqv = qvalue.QValueInt64{Val: i64} + case uint64: + newqv = qvalue.QValueInt64{Val: int64(i64)} + } + } + if newqv == nil { + ls.RaiseError("invalid int64") + } + case qvalue.QValueKindBoolean: + newqv = qvalue.QValueBoolean{Val: lua.LVAsBool(val)} + case qvalue.QValueKindQChar: + switch v := val.(type) { + case lua.LNumber: + newqv = qvalue.QValueQChar{Val: uint8(v)} + case lua.LString: + if len(v) > 0 { + newqv = qvalue.QValueQChar{Val: v[0]} + } + default: + ls.RaiseError("invalid \"char\"") + } + case qvalue.QValueKindString: + newqv = qvalue.QValueString{Val: lua.LVAsString(val)} + /* TODO time + case qvalue.QValueKindTimestamp: + newqv = qvalue.QValueTimestamp{Val:} + case qvalue.QValueKindTimestampTZ: + newqv = qvalue.QValueTimestampTZ{Val:} + case qvalue.QValueKindDate: + newqv = qvalue.QValueDate{Val:} + case qvalue.QValueKindTime: + newqv = qvalue.QValueTime{Val:} + case qvalue.QValueKindTimeTZ: + newqv = qvalue.QValueTimeTZ{Val:} + */ + case qvalue.QValueKindNumeric: + if ud, ok := val.(*lua.LUserData); ok { + if num, ok := ud.Value.(decimal.Decimal); ok { + newqv = qvalue.QValueNumeric{Val: num} + } + } + case qvalue.QValueKindBytes: + newqv = qvalue.QValueBytes{Val: []byte(lua.LVAsString(val))} + case qvalue.QValueKindUUID: + if ud, ok := val.(*lua.LUserData); ok { + if id, ok := ud.Value.(uuid.UUID); ok { + newqv = qvalue.QValueUUID{Val: [16]byte(id)} + } + } + case qvalue.QValueKindJSON: + newqv = qvalue.QValueJSON{Val: lua.LVAsString(val)} + case qvalue.QValueKindBit: + newqv = qvalue.QValueBit{Val: []byte(lua.LVAsString(val))} + case qvalue.QValueKindArrayFloat32: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayFloat32{ + Val: shared.LTableToSlice(ls, tbl, func(_ *lua.LState, v lua.LValue) float32 { + return float32(lua.LVAsNumber(v)) + }), + } + } + case qvalue.QValueKindArrayFloat64: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayFloat64{ + Val: shared.LTableToSlice(ls, tbl, func(_ *lua.LState, v lua.LValue) float64 { + return float64(lua.LVAsNumber(v)) + }), + } + } + case qvalue.QValueKindArrayInt16: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayFloat64{ + Val: shared.LTableToSlice(ls, tbl, func(_ *lua.LState, v lua.LValue) float64 { + return float64(lua.LVAsNumber(v)) + }), + } + } + case qvalue.QValueKindArrayInt32: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayFloat64{ + Val: shared.LTableToSlice(ls, tbl, func(_ *lua.LState, v lua.LValue) float64 { + return float64(lua.LVAsNumber(v)) + }), + } + } + case qvalue.QValueKindArrayInt64: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayFloat64{ + Val: shared.LTableToSlice(ls, tbl, func(_ *lua.LState, v lua.LValue) float64 { + return float64(lua.LVAsNumber(v)) + }), + } + } + case qvalue.QValueKindArrayString: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayString{ + Val: shared.LTableToSlice(ls, tbl, func(_ *lua.LState, v lua.LValue) string { + return lua.LVAsString(v) + }), + } + } + /* TODO TIME + case qvalue.QValueKindArrayDate: + newqv = qvalue.QValueArrayDate{Val:} + case qvalue.QValueKindArrayTimestamp: + newqv = qvalue.QValueArrayTimestamp{Val:} + case qvalue.QValueKindArrayTimestampTZ: + newqv = qvalue.QValueArrayTimestampTZ{Val:} + */ + case qvalue.QValueKindArrayBoolean: + if tbl, ok := val.(*lua.LTable); ok { + newqv = qvalue.QValueArrayBoolean{ + Val: shared.LTableToSlice(ls, tbl, func(_ *lua.LState, v lua.LValue) bool { + return lua.LVAsBool(v) + }), + } + } + default: + ls.RaiseError(fmt.Sprintf("no support for reassigning %s", kind)) + return 0 + } + + row.AddColumn(key, newqv) + return 1 +} + func LuaRowLen(ls *lua.LState) int { row := LuaRow.StartMethod(ls) ls.Push(lua.LNumber(len(row.ColToVal))) diff --git a/flow/shared/lua.go b/flow/shared/lua.go index 2b95c3464c..26aeb4fe2a 100644 --- a/flow/shared/lua.go +++ b/flow/shared/lua.go @@ -18,7 +18,7 @@ var ( LuaDecimal = glua64.UserDataType[decimal.Decimal]{Name: "peerdb_decimal"} ) -func SliceToLTable[T any](ls *lua.LState, s []T, f func(x T) lua.LValue) *lua.LTable { +func SliceToLTable[T any](ls *lua.LState, s []T, f func(T) lua.LValue) *lua.LTable { tbl := ls.CreateTable(len(s), 0) tbl.Metatable = ls.GetTypeMetatable("Array") for idx, val := range s { @@ -26,3 +26,12 @@ func SliceToLTable[T any](ls *lua.LState, s []T, f func(x T) lua.LValue) *lua.LT } return tbl } + +func LTableToSlice[T any](ls *lua.LState, tbl *lua.LTable, f func(*lua.LState, lua.LValue) T) []T { + tlen := tbl.Len() + slice := make([]T, 0, tlen) + for i := range tlen { + slice = append(slice, f(ls, tbl.RawGetInt(i))) + } + return slice +}