Skip to content

Commit

Permalink
Don't expose QValue directly to scripts
Browse files Browse the repository at this point in the history
2 reasons:
1. I'd like to eventually replace QValue with interface{}
2. Avoids having to sprinkle `.value` throughout script
  • Loading branch information
serprex committed Mar 17, 2024
1 parent cfb9579 commit 85ad7a8
Showing 1 changed file with 95 additions and 116 deletions.
211 changes: 95 additions & 116 deletions flow/pua/peerdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ import (
"bytes"
"fmt"
"math/big"
"reflect"
"strconv"
"strings"
"time"

"github.com/google/uuid"
Expand All @@ -22,7 +20,6 @@ import (
var (
LuaRecord = LuaUserDataType[model.Record]{Name: "peerdb_record"}
LuaRow = LuaUserDataType[*model.RecordItems]{Name: "peerdb_row"}
LuaQValue = LuaUserDataType[qvalue.QValue]{Name: "peerdb_value"}
LuaI64 = LuaUserDataType[int64]{Name: "flatbuffers_i64"}
LuaU64 = LuaUserDataType[uint64]{Name: "flatbuffers_u64"}
LuaTime = LuaUserDataType[time.Time]{Name: "peerdb_time"}
Expand All @@ -44,11 +41,6 @@ func RegisterTypes(ls *lua.LState) {
mt.RawSetString("__index", ls.NewFunction(LuaRowIndex))
mt.RawSetString("__len", ls.NewFunction(LuaRowLen))

mt = LuaQValue.NewMetatable(ls)
mt.RawSetString("__index", ls.NewFunction(LuaQValueIndex))
mt.RawSetString("__tostring", ls.NewFunction(LuaQValueString))
mt.RawSetString("__len", ls.NewFunction(LuaQValueLen))

mt = LuaUuid.NewMetatable(ls)
mt.RawSetString("__index", ls.NewFunction(LuaUuidIndex))
mt.RawSetString("__tostring", ls.NewFunction(LuaUuidString))
Expand Down Expand Up @@ -83,9 +75,11 @@ func RegisterTypes(ls *lua.LState) {
peerdb := ls.NewTable()
peerdb.RawSetString("RowToJSON", ls.NewFunction(LuaRowToJSON))
peerdb.RawSetString("RowColumns", ls.NewFunction(LuaRowColumns))
peerdb.RawSetString("RowColumnKind", ls.NewFunction(LuaRowColumnKind))
peerdb.RawSetString("Now", ls.NewFunction(LuaNow))
peerdb.RawSetString("UUID", ls.NewFunction(LuaUUID))
peerdb.RawSetString("type", ls.NewFunction(LuaType))
peerdb.RawSetString("tostring", ls.NewFunction(LuaToString))
ls.Env.RawSetString("peerdb", peerdb)
}

Expand Down Expand Up @@ -117,16 +111,20 @@ func LoadPeerdbScript(ls *lua.LState) int {
return 1
}

func LuaRowIndex(ls *lua.LState) int {
func GetRowQ(ls *lua.LState, row *model.RecordItems, col string) qvalue.QValue {

Check failure on line 114 in flow/pua/peerdb.go

View workflow job for this annotation

GitHub Actions / lint

SA4009: argument row is overwritten before first use (staticcheck)
row, key := LuaRow.StartIndex(ls)

Check failure on line 115 in flow/pua/peerdb.go

View workflow job for this annotation

GitHub Actions / lint

SA4009(related information): assignment to row (staticcheck)

qv, err := row.GetValueByColName(key)
if err != nil {
ls.RaiseError(err.Error())
return 0
return qvalue.QValue{}
}
return qv
}

ls.Push(LuaQValue.New(ls, qv))
func LuaRowIndex(ls *lua.LState) int {
row, key := LuaRow.StartIndex(ls)
ls.Push(LuaQValue(ls, GetRowQ(ls, row, key)))
return 1
}

Expand Down Expand Up @@ -157,6 +155,12 @@ func LuaRowColumns(ls *lua.LState) int {
return 1
}

func LuaRowColumnKind(ls *lua.LState) int {
row, key := LuaRow.StartIndex(ls)
ls.Push(lua.LString(GetRowQ(ls, row, key).Kind))
return 1
}

func LuaRecordIndex(ls *lua.LState) int {
record, key := LuaRecord.StartIndex(ls)
switch key {
Expand Down Expand Up @@ -216,121 +220,87 @@ func LuaRecordIndex(ls *lua.LState) int {
return 1
}

func qvToLTable[T any](ls *lua.LState, s []T, f func(x T) lua.LValue) {
func qvToLTable[T any](ls *lua.LState, s []T, f func(x T) lua.LValue) *lua.LTable {
tbl := ls.CreateTable(len(s), 0)
for idx, val := range s {
tbl.RawSetInt(idx, f(val))
}
ls.Push(tbl)
return tbl
}

func LuaQValueIndex(ls *lua.LState) int {
qv, key := LuaQValue.StartIndex(ls)
switch key {
case "kind":
ls.Push(lua.LString(qv.Kind))
case "value":
switch v := qv.Value.(type) {
case nil:
ls.Push(lua.LNil)
case bool:
ls.Push(lua.LBool(v))
case uint8:
if qv.Kind == qvalue.QValueKindQChar {
ls.Push(lua.LString(rune(v)))
} else {
ls.Push(lua.LNumber(v))
}
case int16:
ls.Push(lua.LNumber(v))
case int32:
ls.Push(lua.LNumber(v))
case int64:
ls.Push(LuaI64.New(ls, v))
case float32:
ls.Push(lua.LNumber(v))
case float64:
ls.Push(lua.LNumber(v))
case string:
if qv.Kind == qvalue.QValueKindUUID {
u, err := uuid.Parse(v)
if err != nil {
ls.Push(LuaUuid.New(ls, u))
} else {
ls.Push(lua.LString(v))
}
} else {
ls.Push(lua.LString(v))
func LuaQValue(ls *lua.LState, qv qvalue.QValue) lua.LValue {
switch v := qv.Value.(type) {
case nil:
return lua.LNil
case bool:
return lua.LBool(v)
case uint8:
if qv.Kind == qvalue.QValueKindQChar {
return lua.LString(rune(v))
} else {
return lua.LNumber(v)
}
case int16:
return lua.LNumber(v)
case int32:
return lua.LNumber(v)
case int64:
return LuaI64.New(ls, v)
case float32:
return lua.LNumber(v)
case float64:
return lua.LNumber(v)
case string:
if qv.Kind == qvalue.QValueKindUUID {
u, err := uuid.Parse(v)
if err != nil {
return LuaUuid.New(ls, u)
}
case time.Time:
ls.Push(LuaTime.New(ls, v))
case decimal.Decimal:
ls.Push(LuaDecimal.New(ls, v))
case [16]byte:
ls.Push(LuaUuid.New(ls, uuid.UUID(v)))
case []byte:
ls.Push(lua.LString(v))
case []float32:
qvToLTable(ls, v, func(f float32) lua.LValue {
return lua.LNumber(f)
})
case []float64:
qvToLTable(ls, v, func(f float64) lua.LValue {
return lua.LNumber(f)
})
case []int16:
qvToLTable(ls, v, func(f int16) lua.LValue {
return lua.LNumber(f)
})
case []int32:
qvToLTable(ls, v, func(f int32) lua.LValue {
return lua.LNumber(f)
})
case []int64:
qvToLTable(ls, v, func(x int64) lua.LValue {
return LuaI64.New(ls, x)
})
case []string:
qvToLTable(ls, v, func(x string) lua.LValue {
return lua.LString(x)
})
case []time.Time:
qvToLTable(ls, v, func(x time.Time) lua.LValue {
return LuaTime.New(ls, x)
})
case []bool:
qvToLTable(ls, v, func(x bool) lua.LValue {
return lua.LBool(x)
})
}
case "int64":
ls.Push(LuaI64.New(ls, reflect.ValueOf(qv.Value).Int()))
case "float64":
ls.Push(lua.LNumber(reflect.ValueOf(qv.Value).Float()))
return lua.LString(v)
case time.Time:
return LuaTime.New(ls, v)
case decimal.Decimal:
return LuaDecimal.New(ls, v)
case [16]byte:
return LuaUuid.New(ls, uuid.UUID(v))
case []byte:
return lua.LString(v)
case []float32:
return qvToLTable(ls, v, func(f float32) lua.LValue {
return lua.LNumber(f)
})
case []float64:
return qvToLTable(ls, v, func(f float64) lua.LValue {
return lua.LNumber(f)
})
case []int16:
return qvToLTable(ls, v, func(x int16) lua.LValue {
return lua.LNumber(x)
})
case []int32:
return qvToLTable(ls, v, func(x int32) lua.LValue {
return lua.LNumber(x)
})
case []int64:
return qvToLTable(ls, v, func(x int64) lua.LValue {
return LuaI64.New(ls, x)
})
case []string:
return qvToLTable(ls, v, func(x string) lua.LValue {
return lua.LString(x)
})
case []time.Time:
return qvToLTable(ls, v, func(x time.Time) lua.LValue {
return LuaTime.New(ls, x)
})
case []bool:
return qvToLTable(ls, v, func(x bool) lua.LValue {
return lua.LBool(x)
})
default:
return 0
return lua.LString(fmt.Sprint(qv.Value))
}
return 1
}

func LuaQValueLen(ls *lua.LState) int {
qv := LuaQValue.StartMeta(ls)
str, ok := qv.Value.(string)
if ok {
ls.Push(lua.LNumber(len(str)))
return 1
}
if strings.HasPrefix(string(qv.Kind), "array_") {
ls.Push(lua.LNumber(reflect.ValueOf(qv.Value).Len()))
return 1
}
return 0
}

func LuaQValueString(ls *lua.LState) int {
qv := LuaQValue.StartMeta(ls)
ls.Push(lua.LString(fmt.Sprint(qv.Value)))
return 1
}

func LuaUuidIndex(ls *lua.LState) int {
Expand Down Expand Up @@ -369,6 +339,15 @@ func LuaType(ls *lua.LState) int {
return 0
}

func LuaToString(ls *lua.LState) int {
val := ls.Get(1)
if ud, ok := val.(*lua.LUserData); ok {
ls.Push(lua.LString(fmt.Sprint(ud.Value)))
return 1
}
return 0
}

func Lua64Eq(ls *lua.LState) int {
aud := ls.CheckUserData(1)
bud := ls.CheckUserData(2)
Expand Down

0 comments on commit 85ad7a8

Please sign in to comment.