diff --git a/arrow/json/writer.go b/arrow/json/writer.go index 381d0b4..25c9ce1 100644 --- a/arrow/json/writer.go +++ b/arrow/json/writer.go @@ -3,14 +3,17 @@ package json import ( "encoding/json" "errors" + "fmt" "io" + "strings" "github.com/apache/arrow/go/arrow" "github.com/apache/arrow/go/arrow/array" ) var ( - ErrMismatchFields = errors.New("arrow/json: number of records mismatch") + ErrMismatchFields = errors.New("arrow/json: number of records mismatch") + ErrUnsupportedType = errors.New("arrow/json: unsupported type") ) // JsonEncoder wraps encoding/json.Encoder and writes array.Record based on a schema. @@ -25,9 +28,6 @@ type Encoder struct { // NewWriter panics if the given schema contains fields that have types that are not // primitive types. func NewWriter(w io.Writer, schema *arrow.Schema) *Encoder { - // TODO - // validate(schema) - ww := &Encoder{ e: json.NewEncoder(w), schema: schema, @@ -49,121 +49,275 @@ func (e *Encoder) Write(record array.Record) error { recs[i] = make(map[string]interface{}, record.NumCols()) } - for j, col := range record.Columns() { - field := e.schema.Field(j) - switch field.Type.(type) { - case *arrow.BooleanType: - arr := col.(*array.Boolean) - for i := 0; i < arr.Len(); i++ { - if arr.IsValid(i) { - recs[i][field.Name] = arr.Value(i) - } else { - recs[i][field.Name] = nil - } - } - case *arrow.Int8Type: - arr := col.(*array.Int8) - for i := 0; i < arr.Len(); i++ { - if arr.IsValid(i) { - recs[i][field.Name] = arr.Value(i) - } else { - recs[i][field.Name] = nil - } - } - case *arrow.Int16Type: - arr := col.(*array.Int16) - for i := 0; i < arr.Len(); i++ { - if arr.IsValid(i) { - recs[i][field.Name] = arr.Value(i) - } else { - recs[i][field.Name] = nil - } - } - case *arrow.Int32Type: - arr := col.(*array.Int32) - for i := 0; i < arr.Len(); i++ { - if arr.IsValid(i) { - recs[i][field.Name] = arr.Value(i) - } else { - recs[i][field.Name] = nil - } - } - case *arrow.Int64Type: - arr := col.(*array.Int64) - for i := 0; i < arr.Len(); i++ { - if arr.IsValid(i) { - recs[i][field.Name] = arr.Value(i) - } else { - recs[i][field.Name] = nil - } - } - case *arrow.Uint8Type: - arr := col.(*array.Uint8) - for i := 0; i < arr.Len(); i++ { - if arr.IsValid(i) { - recs[i][field.Name] = arr.Value(i) - } else { - recs[i][field.Name] = nil - } - } - case *arrow.Uint16Type: - arr := col.(*array.Uint16) - for i := 0; i < arr.Len(); i++ { - if arr.IsValid(i) { - recs[i][field.Name] = arr.Value(i) - } else { - recs[i][field.Name] = nil - } - } - case *arrow.Uint32Type: - arr := col.(*array.Uint32) - for i := 0; i < arr.Len(); i++ { - if arr.IsValid(i) { - recs[i][field.Name] = arr.Value(i) - } else { - recs[i][field.Name] = nil - } - } - case *arrow.Uint64Type: - arr := col.(*array.Uint64) - for i := 0; i < arr.Len(); i++ { - if arr.IsValid(i) { - recs[i][field.Name] = arr.Value(i) - } else { - recs[i][field.Name] = nil - } - } - case *arrow.Float32Type: - arr := col.(*array.Float32) - for i := 0; i < arr.Len(); i++ { - if arr.IsValid(i) { - recs[i][field.Name] = arr.Value(i) - } else { - recs[i][field.Name] = nil - } - } - case *arrow.Float64Type: - arr := col.(*array.Float64) - for i := 0; i < arr.Len(); i++ { - if arr.IsValid(i) { - recs[i][field.Name] = arr.Value(i) - } else { - recs[i][field.Name] = nil - } - } - case *arrow.StringType: - arr := col.(*array.String) - for i := 0; i < arr.Len(); i++ { - if arr.IsValid(i) { - recs[i][field.Name] = arr.Value(i) - } else { - recs[i][field.Name] = nil + for i, col := range record.Columns() { + if err := writeData(col.Data(), &recs, []string{e.schema.Field(i).Name}); err != nil { + return err + } + } + + return e.e.Encode(recs) +} + +func writeData(data *array.Data, recs *[]map[string]interface{}, names []string) error { + switch data.DataType().ID() { + case arrow.BOOL: + arr := array.NewBooleanData(data) + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + if err := deepSet(&(*recs)[i], names, arr.Value(i)); err != nil { + return err + } + } else { + if err := deepSet(&(*recs)[i], names, nil); err != nil { + return err + } + } + } + + case arrow.INT8: + arr := array.NewInt8Data(data) + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + if err := deepSet(&(*recs)[i], names, arr.Value(i)); err != nil { + return err + } + } else { + if err := deepSet(&(*recs)[i], names, nil); err != nil { + return err + } + } + } + + case arrow.INT16: + arr := array.NewInt16Data(data) + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + if err := deepSet(&(*recs)[i], names, arr.Value(i)); err != nil { + return err + } + } else { + if err := deepSet(&(*recs)[i], names, nil); err != nil { + return err } } + } + + case arrow.INT32: + arr := array.NewInt32Data(data) + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + if err := deepSet(&(*recs)[i], names, arr.Value(i)); err != nil { + return err + } + } else { + if err := deepSet(&(*recs)[i], names, nil); err != nil { + return err + } + } + } - // TODO more types + case arrow.INT64: + arr := array.NewInt64Data(data) + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + if err := deepSet(&(*recs)[i], names, arr.Value(i)); err != nil { + return err + } + } else { + if err := deepSet(&(*recs)[i], names, nil); err != nil { + return err + } + } + } + + case arrow.UINT8: + arr := array.NewUint8Data(data) + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + if err := deepSet(&(*recs)[i], names, arr.Value(i)); err != nil { + return err + } + } else { + if err := deepSet(&(*recs)[i], names, nil); err != nil { + return err + } + } + } + + case arrow.UINT16: + arr := array.NewUint16Data(data) + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + if err := deepSet(&(*recs)[i], names, arr.Value(i)); err != nil { + return err + } + } else { + if err := deepSet(&(*recs)[i], names, nil); err != nil { + return err + } + } + } + + case arrow.UINT32: + arr := array.NewUint32Data(data) + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + if err := deepSet(&(*recs)[i], names, arr.Value(i)); err != nil { + return err + } + } else { + if err := deepSet(&(*recs)[i], names, nil); err != nil { + return err + } + } + } + + case arrow.UINT64: + arr := array.NewUint64Data(data) + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + if err := deepSet(&(*recs)[i], names, arr.Value(i)); err != nil { + return err + } + } else { + if err := deepSet(&(*recs)[i], names, nil); err != nil { + return err + } + } } + + case arrow.FLOAT32: + arr := array.NewFloat32Data(data) + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + if err := deepSet(&(*recs)[i], names, arr.Value(i)); err != nil { + return err + } + } else { + if err := deepSet(&(*recs)[i], names, nil); err != nil { + return err + } + } + } + + case arrow.FLOAT64: + arr := array.NewFloat64Data(data) + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + if err := deepSet(&(*recs)[i], names, arr.Value(i)); err != nil { + return err + } + } else { + if err := deepSet(&(*recs)[i], names, nil); err != nil { + return err + } + } + } + + case arrow.STRING: + arr := array.NewStringData(data) + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + if err := deepSet(&(*recs)[i], names, arr.Value(i)); err != nil { + return err + } + } else { + if err := deepSet(&(*recs)[i], names, nil); err != nil { + return err + } + } + } + + case arrow.BINARY: + arr := array.NewBinaryData(data) + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + if err := deepSet(&(*recs)[i], names, arr.Value(i)); err != nil { + return err + } + } else { + if err := deepSet(&(*recs)[i], names, nil); err != nil { + return err + } + } + } + + case arrow.STRUCT: + arr := array.NewStructData(data) + st, stOk := arr.DataType().(*arrow.StructType) + if !stOk { + return fmt.Errorf("unsupported data type %v: %w", arr.DataType(), ErrUnsupportedType) + } + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + for i := 0; i < arr.NumField(); i++ { + n := st.Field(i).Name + d := arr.Field(i).Data() + if err := writeData(d, recs, append(names, n)); err != nil { + return err + } + } + } else { + if err := deepSet(&(*recs)[i], names, nil); err != nil { + return err + } + } + } + + case arrow.LIST: + arr := array.NewListData(data) + for i := 0; i < arr.Len(); i++ { + if arr.IsValid(i) { + o := i + arr.Offset() + bgn := int64(arr.Offsets()[o]) + end := int64(arr.Offsets()[o+1]) + slice := array.NewSlice(arr.ListValues(), bgn, end) + if err := writeData(slice.Data(), recs, names); err != nil { + return err + } + } else { + if err := deepSet(&(*recs)[i], names, nil); err != nil { + return err + } + } + } + + default: + return ErrUnsupportedType } - return e.e.Encode(recs) + return nil +} + +func deepSet(recv *map[string]interface{}, keys []string, value interface{}) error { + cur := *recv + numKeys := len(keys) + + if numKeys > 1 { + for _, k := range keys[:numKeys-1] { + sub, subOk := (*recv)[k] + if !subOk { + return fmt.Errorf("no entry to %v", strings.Join(keys, ".")) + } + + typed, typedOk := sub.(map[string]interface{}) + if !typedOk { + return fmt.Errorf("unexpected type of value %v", sub) + } + + cur = typed + } + } + + if vv, ok := cur[keys[numKeys-1]]; ok { + if arr, arrOk := vv.([]interface{}); arrOk { + cur[keys[numKeys-1]] = append(arr, value) + } else { + cur[keys[numKeys-1]] = []interface{}{vv, value} + } + } else { + cur[keys[numKeys-1]] = value + } + + return nil } diff --git a/arrow/json/writer_test.go b/arrow/json/writer_test.go index d808b38..e3d33f2 100644 --- a/arrow/json/writer_test.go +++ b/arrow/json/writer_test.go @@ -43,7 +43,7 @@ func Example_writer() { } } -func TestCSVWriter(t *testing.T) { +func TestJsonWriter(t *testing.T) { tests := []struct { name string }{{ @@ -51,12 +51,12 @@ func TestCSVWriter(t *testing.T) { }} for _, test := range tests { t.Run(test.name, func(t *testing.T) { - testCSVWriter(t) + testJsonWriter(t) }) } } -func testCSVWriter(t *testing.T) { +func testJsonWriter(t *testing.T) { f := new(bytes.Buffer) pool := memory.NewCheckedAllocator(memory.NewGoAllocator()) @@ -110,8 +110,8 @@ func testCSVWriter(t *testing.T) { want := strings.ReplaceAll(`[ {"bool":true,"f32":0,"f64":0,"i16":-1,"i32":-1,"i64":-1,"i8":-1,"str":"str-0","u16":0,"u32":0,"u64":0,"u8":0}, -{"bool":false,"f32":0.1,"f64":0.1,"i16":0,"i32":0,"i64":0,"i8":0,"str":"str-1","u16":1,"u32":1,"u64":1,"u8":1} -,{"bool":true,"f32":0.2,"f64":0.2,"i16":1,"i32":1,"i64":1,"i8":1,"str":"str-2","u16":2,"u32":2,"u64":2,"u8":2}, +{"bool":false,"f32":0.1,"f64":0.1,"i16":0,"i32":0,"i64":0,"i8":0,"str":"str-1","u16":1,"u32":1,"u64":1,"u8":1}, +{"bool":true,"f32":0.2,"f64":0.2,"i16":1,"i32":1,"i64":1,"i8":1,"str":"str-2","u16":2,"u32":2,"u64":2,"u8":2}, {"bool":null,"f32":null,"f64":null,"i16":null,"i32":null,"i64":null,"i8":null,"str":null,"u16":null,"u32":null,"u64":null,"u8":null}] `, "\n", "") + "\n"