From 2064afbce20e289d80625fa0b869c429eb24e94c Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Wed, 9 Aug 2023 17:31:51 -0300 Subject: [PATCH 1/7] add batch insert --- db/dialect.go | 3 ++ db/dialect_clickhouse.go | 67 ++++++++++++++++++++++++++++++++++++++ db/dialect_postgres.go | 37 +++++++++++++++++++++ db/flush.go | 46 +++++++++----------------- db/operations.go | 70 +++++++++++++++++++++++++++++++++++++++- db/ops.go | 7 ++++ 6 files changed, 199 insertions(+), 31 deletions(-) diff --git a/db/dialect.go b/db/dialect.go index 6d8bd95..4449d87 100644 --- a/db/dialect.go +++ b/db/dialect.go @@ -2,6 +2,7 @@ package db import ( "context" + "database/sql" "fmt" sink "github.com/streamingfast/substreams-sink" @@ -22,6 +23,8 @@ type dialect interface { DriverSupportRowsAffected() bool GetUpdateCursorQuery(table, moduleHash string, cursor *sink.Cursor, block_num uint64, block_id string) string ParseDatetimeNormalization(value string) string + Flush(tx *sql.Tx, ctx context.Context, l *Loader, outputModuleHash string, cursor *sink.Cursor) (int, error) + OnlyInserts() bool } var driverDialect = map[string]dialect{ diff --git a/db/dialect_clickhouse.go b/db/dialect_clickhouse.go index 9ab5c6a..fa99a3f 100644 --- a/db/dialect_clickhouse.go +++ b/db/dialect_clickhouse.go @@ -2,15 +2,78 @@ package db import ( "context" + "database/sql" "fmt" + "sort" "strings" "github.com/streamingfast/cli" sink "github.com/streamingfast/substreams-sink" + "go.uber.org/zap" ) type clickhouseDialect struct{} +func (d clickhouseDialect) Flush(tx *sql.Tx, ctx context.Context, l *Loader, outputModuleHash string, cursor *sink.Cursor) (int, error) { + var entryCount int + for entriesPair := l.entries.Oldest(); entriesPair != nil; entriesPair = entriesPair.Next() { + tableName := entriesPair.Key + entries := entriesPair.Value + tx, err := l.DB.BeginTx(ctx, nil) + if err != nil { + return entryCount, fmt.Errorf("failed to begin db transaction") + } + + if l.tracer.Enabled() { + 
l.logger.Debug("flushing table entries", zap.String("table_name", tableName), zap.Int("entry_count", entries.Len())) + } + info := l.tables[tableName] + columns := make([]string, 0, len(info.columnsByName)) + for column := range info.columnsByName { + columns = append(columns, column) + } + sort.Strings(columns) + query := fmt.Sprintf( + "INSERT INTO %s.%s (%s)", + EscapeIdentifier(l.schema), + EscapeIdentifier(tableName), + strings.Join(columns, ",")) + // fmt.Println(query) + batch, err := tx.Prepare(query) + if err != nil { + return entryCount, fmt.Errorf("failed to prepare insert into %q: %w", tableName, err) + } + for entryPair := entries.Oldest(); entryPair != nil; entryPair = entryPair.Next() { + entry := entryPair.Value + + if err != nil { + return entryCount, fmt.Errorf("failed to get query: %w", err) + } + + if l.tracer.Enabled() { + l.logger.Debug("adding query from operation to transaction", zap.Stringer("op", entry), zap.String("query", query)) + } + + values, err := entry.getValues() + if err != nil { + return entryCount, fmt.Errorf("failed to get values: %w", err) + } + + if _, err := batch.ExecContext(ctx, values...); err != nil { + return entryCount, fmt.Errorf("executing for entry %q: %w", values, err) + } + } + + // fmt.Println("flushing batch") + if err := tx.Commit(); err != nil { + return entryCount, fmt.Errorf("failed to commit db transaction: %w", err) + } + entryCount += entries.Len() + } + + return entryCount, nil +} + func (d clickhouseDialect) GetCreateCursorQuery(schema string) string { return fmt.Sprintf(cli.Dedent(` CREATE TABLE IF NOT EXISTS %s.%s @@ -48,3 +111,7 @@ func (d clickhouseDialect) ParseDatetimeNormalization(value string) string { func (d clickhouseDialect) DriverSupportRowsAffected() bool { return false } + +func (d clickhouseDialect) OnlyInserts() bool { + return true +} diff --git a/db/dialect_postgres.go b/db/dialect_postgres.go index 3509697..a73e1de 100644 --- a/db/dialect_postgres.go +++ b/db/dialect_postgres.go @@ 
-2,14 +2,47 @@ package db import ( "context" + "database/sql" "fmt" "github.com/streamingfast/cli" sink "github.com/streamingfast/substreams-sink" + "go.uber.org/zap" ) type postgresDialect struct{} +func (d postgresDialect) Flush(tx *sql.Tx, ctx context.Context, l *Loader, outputModuleHash string, cursor *sink.Cursor) (int, error) { + var entryCount int + for entriesPair := l.entries.Oldest(); entriesPair != nil; entriesPair = entriesPair.Next() { + tableName := entriesPair.Key + entries := entriesPair.Value + + if l.tracer.Enabled() { + l.logger.Debug("flushing table entries", zap.String("table_name", tableName), zap.Int("entry_count", entries.Len())) + } + for entryPair := entries.Oldest(); entryPair != nil; entryPair = entryPair.Next() { + entry := entryPair.Value + + query, err := entry.query(l.getDialect()) + if err != nil { + return 0, fmt.Errorf("failed to get query: %w", err) + } + + if l.tracer.Enabled() { + l.logger.Debug("adding query from operation to transaction", zap.Stringer("op", entry), zap.String("query", query)) + } + + if _, err := tx.ExecContext(ctx, query); err != nil { + return 0, fmt.Errorf("executing query %q: %w", query, err) + } + } + entryCount += entries.Len() + } + + return entryCount, nil +} + func (d postgresDialect) GetCreateCursorQuery(schema string) string { return fmt.Sprintf(cli.Dedent(` create table if not exists %s.%s @@ -42,3 +75,7 @@ func (d postgresDialect) ParseDatetimeNormalization(value string) string { func (d postgresDialect) DriverSupportRowsAffected() bool { return true } + +func (d postgresDialect) OnlyInserts() bool { + return false +} diff --git a/db/flush.go b/db/flush.go index cf010f4..eec2554 100644 --- a/db/flush.go +++ b/db/flush.go @@ -3,6 +3,8 @@ package db import ( "context" "fmt" + // "sort" + // "strings" "time" "github.com/ClickHouse/clickhouse-go/v2" @@ -13,7 +15,6 @@ import ( func (l *Loader) Flush(ctx context.Context, outputModuleHash string, cursor *sink.Cursor) (err error) { ctx = 
clickhouse.Context(context.Background(), clickhouse.WithStdAsync(false)) startAt := time.Now() - tx, err := l.DB.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("failed to being db transaction: %w", err) @@ -25,36 +26,10 @@ func (l *Loader) Flush(ctx context.Context, outputModuleHash string, cursor *sin } } }() - - var entryCount int - for entriesPair := l.entries.Oldest(); entriesPair != nil; entriesPair = entriesPair.Next() { - tableName := entriesPair.Key - entries := entriesPair.Value - - if l.tracer.Enabled() { - l.logger.Debug("flushing table entries", zap.String("table_name", tableName), zap.Int("entry_count", entries.Len())) - } - - for entryPair := entries.Oldest(); entryPair != nil; entryPair = entryPair.Next() { - entry := entryPair.Value - - query, err := entry.query(l.getDialect()) - if err != nil { - return fmt.Errorf("failed to get query: %w", err) - } - - if l.tracer.Enabled() { - l.logger.Debug("adding query from operation to transaction", zap.Stringer("op", entry), zap.String("query", query)) - } - - if _, err := tx.ExecContext(ctx, query); err != nil { - return fmt.Errorf("executing query %q: %w", query, err) - } - } - - entryCount += entries.Len() + entryCount, err := l.getDialect().Flush(tx, ctx, l, outputModuleHash, cursor) + if err != nil { + return fmt.Errorf("dialect flush: %w", err) } - entryCount += 1 if err := l.UpdateCursor(ctx, tx, outputModuleHash, cursor); err != nil { return fmt.Errorf("update cursor: %w", err) @@ -69,6 +44,17 @@ func (l *Loader) Flush(ctx context.Context, outputModuleHash string, cursor *sin return nil } +func onlyOperation(op OperationType, entries *OrderedMap[string, *Operation]) bool { + sameOperation := true + for entryPair := entries.Oldest(); entryPair != nil; entryPair = entryPair.Next() { + if entryPair.Value.opType != op { + sameOperation = false + break + } + } + return sameOperation +} + func (l *Loader) reset() { for entriesPair := l.entries.Oldest(); entriesPair != nil; entriesPair = 
entriesPair.Next() { l.entries.Set(entriesPair.Key, NewOrderedMap[string, *Operation]()) diff --git a/db/operations.go b/db/operations.go index 765b810..c0350fb 100644 --- a/db/operations.go +++ b/db/operations.go @@ -2,8 +2,10 @@ package db import ( "fmt" + "math/big" "reflect" "regexp" + "sort" "strconv" "strings" "time" @@ -73,6 +75,25 @@ func (o *Operation) mergeData(newData map[string]string) error { return nil } +func (o *Operation) getValues() ([]any, error) { + columns := make([]string, len(o.data)) + i := 0 + for column := range o.data { + columns[i] = column + i++ + } + sort.Strings(columns) + values := make([]any, len(o.data)) + for i, v := range columns { + convertedType, err := convertToType(o.data[v], o.table.columnsByName[v].scanType) + if err != nil { + return nil, fmt.Errorf("converting value %q to type %q: %w", o.data[v], o.table.columnsByName[v].scanType, err) + } + values[i] = convertedType + } + return values, nil +} + func (o *Operation) query(d dialect) (string, error) { var columns, values []string if o.opType == OperationTypeInsert || o.opType == OperationTypeUpdate { @@ -163,6 +184,54 @@ func prepareColValues(d dialect, table *TableInfo, colValues map[string]string) var integerRegex = regexp.MustCompile(`^\d+$`) var reflectTypeTime = reflect.TypeOf(time.Time{}) +func convertToType(value string, valueType reflect.Type) (any, error) { + switch valueType.Kind() { + case reflect.String: + return value, nil + case reflect.Slice: + return value, nil + case reflect.Bool: + return strconv.ParseBool(value) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.ParseInt(value, 10, 0) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint64: + return strconv.ParseUint(value, 10, 0) + case reflect.Uint32: + v, err := strconv.ParseUint(value, 10, 32) + return uint32(v), err + case reflect.Float32, reflect.Float64: + return strconv.ParseFloat(value, 10) + case reflect.Struct: + if valueType == 
reflectTypeTime { + if integerRegex.MatchString(value) { + i, err := strconv.Atoi(value) + if err != nil { + return "", fmt.Errorf("could not convert %s to int: %w", value, err) + } + + return int64(i), nil + } + + v, err := time.Parse("2006-01-02T15:04:05Z", value) + if err != nil { + return "", fmt.Errorf("could not convert %s to time: %w", value, err) + } + return v.Unix(), nil + } + return "", fmt.Errorf("unsupported struct type %s", valueType) + + case reflect.Ptr: + if valueType.String() == "*big.Int" { + newInt := new(big.Int) + newInt.SetString(value, 10) + return newInt, nil + } + return "", fmt.Errorf("unsupported pointer type %s", valueType) + default: + return value, nil + } +} + // Format based on type, value returned unescaped func normalizeValueType(value string, valueType reflect.Type, d dialect) (string, error) { switch valueType.Kind() { @@ -203,7 +272,6 @@ func normalizeValueType(value string, valueType reflect.Type, d dialect) (string } return "", fmt.Errorf("unsupported struct type %s", valueType) - default: // It's a column's type the schema parsing don't know how to represents as // a Go type. In that case, we pass it unmodified to the database engine. 
It diff --git a/db/ops.go b/db/ops.go index 6493135..2c66a17 100644 --- a/db/ops.go +++ b/db/ops.go @@ -93,6 +93,9 @@ func (l *Loader) GetPrimaryKey(tableName string, pk string) (map[string]string, // Update a row in the DB, it is assumed the table exists, you can do a // check before with HasTable() func (l *Loader) Update(tableName string, primaryKey map[string]string, data map[string]string) error { + if l.getDialect().OnlyInserts() { + return fmt.Errorf("update operation is not supported by the current database") + } uniqueID := createRowUniqueID(primaryKey) if l.tracer.Enabled() { @@ -141,6 +144,10 @@ func (l *Loader) Update(tableName string, primaryKey map[string]string, data map // Delete a row in the DB, it is assumed the table exists, you can do a // check before with HasTable() func (l *Loader) Delete(tableName string, primaryKey map[string]string) error { + if l.getDialect().OnlyInserts() { + return fmt.Errorf("update operation is not supported by the current database") + } + uniqueID := createRowUniqueID(primaryKey) if l.tracer.Enabled() { l.logger.Debug("processing delete operation", zap.String("table_name", tableName), zap.String("primary_key", uniqueID)) From 7ccb1744cfccb3ee8e4c3e8a3799787b4def31f2 Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Wed, 9 Aug 2023 17:37:02 -0300 Subject: [PATCH 2/7] remove unused function --- db/flush.go | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/db/flush.go b/db/flush.go index eec2554..22f50c9 100644 --- a/db/flush.go +++ b/db/flush.go @@ -44,17 +44,6 @@ func (l *Loader) Flush(ctx context.Context, outputModuleHash string, cursor *sin return nil } -func onlyOperation(op OperationType, entries *OrderedMap[string, *Operation]) bool { - sameOperation := true - for entryPair := entries.Oldest(); entryPair != nil; entryPair = entryPair.Next() { - if entryPair.Value.opType != op { - sameOperation = false - break - } - } - return sameOperation -} - func (l *Loader) reset() { for entriesPair := 
l.entries.Oldest(); entriesPair != nil; entriesPair = entriesPair.Next() { l.entries.Set(entriesPair.Key, NewOrderedMap[string, *Operation]()) From de410d4dbc3843f1ab65c6a402e8b859ffbe80a0 Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Wed, 9 Aug 2023 17:38:08 -0300 Subject: [PATCH 3/7] fix message operation --- db/ops.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/db/ops.go b/db/ops.go index 2c66a17..fb7c974 100644 --- a/db/ops.go +++ b/db/ops.go @@ -145,7 +145,7 @@ func (l *Loader) Update(tableName string, primaryKey map[string]string, data map // check before with HasTable() func (l *Loader) Delete(tableName string, primaryKey map[string]string) error { if l.getDialect().OnlyInserts() { - return fmt.Errorf("update operation is not supported by the current database") + return fmt.Errorf("delete operation is not supported by the current database") } uniqueID := createRowUniqueID(primaryKey) From 460e2546de9c0b36ee001108abea2acda697d16c Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Wed, 9 Aug 2023 17:39:07 -0300 Subject: [PATCH 4/7] remove comments --- db/dialect_clickhouse.go | 2 -- db/flush.go | 2 -- 2 files changed, 4 deletions(-) diff --git a/db/dialect_clickhouse.go b/db/dialect_clickhouse.go index fa99a3f..7383301 100644 --- a/db/dialect_clickhouse.go +++ b/db/dialect_clickhouse.go @@ -38,7 +38,6 @@ func (d clickhouseDialect) Flush(tx *sql.Tx, ctx context.Context, l *Loader, out EscapeIdentifier(l.schema), EscapeIdentifier(tableName), strings.Join(columns, ",")) - // fmt.Println(query) batch, err := tx.Prepare(query) if err != nil { return entryCount, fmt.Errorf("failed to prepare insert into %q: %w", tableName, err) @@ -64,7 +63,6 @@ func (d clickhouseDialect) Flush(tx *sql.Tx, ctx context.Context, l *Loader, out } } - // fmt.Println("flushing batch") if err := tx.Commit(); err != nil { return entryCount, fmt.Errorf("failed to commit db transaction: %w", err) } diff --git a/db/flush.go b/db/flush.go index 22f50c9..87a7a37 
100644 --- a/db/flush.go +++ b/db/flush.go @@ -3,8 +3,6 @@ package db import ( "context" "fmt" - // "sort" - // "strings" "time" "github.com/ClickHouse/clickhouse-go/v2" From 8c1ed3e22ae0c9efefeaa0f4d954fe5485751120 Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Fri, 25 Aug 2023 13:03:28 -0300 Subject: [PATCH 5/7] chore(clickhouse): add comments to flush --- db/dialect_clickhouse.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/db/dialect_clickhouse.go b/db/dialect_clickhouse.go index 7383301..ca7af16 100644 --- a/db/dialect_clickhouse.go +++ b/db/dialect_clickhouse.go @@ -14,6 +14,11 @@ import ( type clickhouseDialect struct{} +// Clickhouse should be used to insert a lot of data in batches. The current official clickhouse +// driver doesn't support Transactions for multiple tables. The only way to add in batches is +// creating a transaction for a table, adding all rows and commiting it. +// +// That's why two different Flush() functions are needed depending on the dialect. 
func (d clickhouseDialect) Flush(tx *sql.Tx, ctx context.Context, l *Loader, outputModuleHash string, cursor *sink.Cursor) (int, error) { var entryCount int for entriesPair := l.entries.Oldest(); entriesPair != nil; entriesPair = entriesPair.Next() { @@ -53,7 +58,7 @@ func (d clickhouseDialect) Flush(tx *sql.Tx, ctx context.Context, l *Loader, out l.logger.Debug("adding query from operation to transaction", zap.Stringer("op", entry), zap.String("query", query)) } - values, err := entry.getValues() + values, err := entry.getValues() if err != nil { return entryCount, fmt.Errorf("failed to get values: %w", err) } From 70c52d87b83ce773a918572a966b30ad0d3297cf Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Fri, 25 Aug 2023 13:18:05 -0300 Subject: [PATCH 6/7] refact: move code to specific dialect --- db/dialect_clickhouse.go | 77 +++++++++++++- db/dialect_postgres.go | 146 ++++++++++++++++++++++++++- db/operations.go | 210 --------------------------------------- db/operations_test.go | 4 +- 4 files changed, 222 insertions(+), 215 deletions(-) diff --git a/db/dialect_clickhouse.go b/db/dialect_clickhouse.go index ca7af16..bda7758 100644 --- a/db/dialect_clickhouse.go +++ b/db/dialect_clickhouse.go @@ -4,8 +4,12 @@ import ( "context" "database/sql" "fmt" + "math/big" + "reflect" "sort" + "strconv" "strings" + "time" "github.com/streamingfast/cli" sink "github.com/streamingfast/substreams-sink" @@ -17,8 +21,6 @@ type clickhouseDialect struct{} // Clickhouse should be used to insert a lot of data in batches. The current official clickhouse // driver doesn't support Transactions for multiple tables. The only way to add in batches is // creating a transaction for a table, adding all rows and commiting it. -// -// That's why two different Flush() functions are needed depending on the dialect. 
func (d clickhouseDialect) Flush(tx *sql.Tx, ctx context.Context, l *Loader, outputModuleHash string, cursor *sink.Cursor) (int, error) { var entryCount int for entriesPair := l.entries.Oldest(); entriesPair != nil; entriesPair = entriesPair.Next() { @@ -58,7 +60,7 @@ func (d clickhouseDialect) Flush(tx *sql.Tx, ctx context.Context, l *Loader, out l.logger.Debug("adding query from operation to transaction", zap.Stringer("op", entry), zap.String("query", query)) } - values, err := entry.getValues() + values, err := convertOpToClickhouseValues(entry) if err != nil { return entryCount, fmt.Errorf("failed to get values: %w", err) } @@ -118,3 +120,72 @@ func (d clickhouseDialect) DriverSupportRowsAffected() bool { func (d clickhouseDialect) OnlyInserts() bool { return true } + + +func convertOpToClickhouseValues(o *Operation) ([]any, error) { + columns := make([]string, len(o.data)) + i := 0 + for column := range o.data { + columns[i] = column + i++ + } + sort.Strings(columns) + values := make([]any, len(o.data)) + for i, v := range columns { + convertedType, err := convertToType(o.data[v], o.table.columnsByName[v].scanType) + if err != nil { + return nil, fmt.Errorf("converting value %q to type %q: %w", o.data[v], o.table.columnsByName[v].scanType, err) + } + values[i] = convertedType + } + return values, nil +} + + +func convertToType(value string, valueType reflect.Type) (any, error) { + switch valueType.Kind() { + case reflect.String: + return value, nil + case reflect.Slice: + return value, nil + case reflect.Bool: + return strconv.ParseBool(value) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.ParseInt(value, 10, 0) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint64: + return strconv.ParseUint(value, 10, 0) + case reflect.Uint32: + v, err := strconv.ParseUint(value, 10, 32) + return uint32(v), err + case reflect.Float32, reflect.Float64: + return strconv.ParseFloat(value, 10) + case reflect.Struct: + 
if valueType == reflectTypeTime { + if integerRegex.MatchString(value) { + i, err := strconv.Atoi(value) + if err != nil { + return "", fmt.Errorf("could not convert %s to int: %w", value, err) + } + + return int64(i), nil + } + + v, err := time.Parse("2006-01-02T15:04:05Z", value) + if err != nil { + return "", fmt.Errorf("could not convert %s to time: %w", value, err) + } + return v.Unix(), nil + } + return "", fmt.Errorf("unsupported struct type %s", valueType) + + case reflect.Ptr: + if valueType.String() == "*big.Int" { + newInt := new(big.Int) + newInt.SetString(value, 10) + return newInt, nil + } + return "", fmt.Errorf("unsupported pointer type %s", valueType) + default: + return value, nil + } +} diff --git a/db/dialect_postgres.go b/db/dialect_postgres.go index a73e1de..0a08ccf 100644 --- a/db/dialect_postgres.go +++ b/db/dialect_postgres.go @@ -4,10 +4,16 @@ import ( "context" "database/sql" "fmt" + "reflect" + "strconv" + "strings" + "time" "github.com/streamingfast/cli" sink "github.com/streamingfast/substreams-sink" "go.uber.org/zap" + + "golang.org/x/exp/maps" ) type postgresDialect struct{} @@ -24,7 +30,7 @@ func (d postgresDialect) Flush(tx *sql.Tx, ctx context.Context, l *Loader, outpu for entryPair := entries.Oldest(); entryPair != nil; entryPair = entryPair.Next() { entry := entryPair.Value - query, err := entry.query(l.getDialect()) + query, err := d.prepareStatement(entry) if err != nil { return 0, fmt.Errorf("failed to get query: %w", err) } @@ -79,3 +85,141 @@ func (d postgresDialect) DriverSupportRowsAffected() bool { func (d postgresDialect) OnlyInserts() bool { return false } + +func (d *postgresDialect) prepareStatement(o *Operation) (string, error) { + var columns, values []string + if o.opType == OperationTypeInsert || o.opType == OperationTypeUpdate { + var err error + columns, values, err = d.prepareColValues(o.table, o.data) + if err != nil { + return "", fmt.Errorf("preparing column & values: %w", err) + } + } + + switch o.opType { 
+ case OperationTypeInsert: + return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", + o.table.identifier, + strings.Join(columns, ","), + strings.Join(values, ","), + ), nil + + case OperationTypeUpdate: + updates := make([]string, len(columns)) + for i := 0; i < len(columns); i++ { + updates[i] = fmt.Sprintf("%s=%s", columns[i], values[i]) + } + + primaryKeySelector := getPrimaryKeyWhereClause(o.primaryKey) + return fmt.Sprintf("UPDATE %s SET %s WHERE %s", + o.table.identifier, + strings.Join(updates, ", "), + primaryKeySelector, + ), nil + + case OperationTypeDelete: + primaryKeyWhereClause := getPrimaryKeyWhereClause(o.primaryKey) + return fmt.Sprintf("DELETE FROM %s WHERE %s", + o.table.identifier, + primaryKeyWhereClause, + ), nil + + default: + panic(fmt.Errorf("unknown operation type %q", o.opType)) + } +} + +func (d *postgresDialect) prepareColValues(table *TableInfo, colValues map[string]string) (columns []string, values []string, err error) { + if len(colValues) == 0 { + return + } + + columns = make([]string, len(colValues)) + values = make([]string, len(colValues)) + + i := 0 + for columnName, value := range colValues { + columnInfo, found := table.columnsByName[columnName] + if !found { + return nil, nil, fmt.Errorf("cannot find column %q for table %q (valid columns are %q)", columnName, table.identifier, strings.Join(maps.Keys(table.columnsByName), ", ")) + } + + normalizedValue, err := d.normalizeValueType(value, columnInfo.scanType) + if err != nil { + return nil, nil, fmt.Errorf("getting sql value from table %s for column %q raw value %q: %w", table.identifier, columnName, value, err) + } + + columns[i] = columnInfo.escapedName + values[i] = normalizedValue + + i++ + } + return +} + +func getPrimaryKeyWhereClause(primaryKey map[string]string) string { + // Avoid any allocation if there is a single primary key + if len(primaryKey) == 1 { + for key, value := range primaryKey { + return EscapeIdentifier(key) + " = " + escapeStringValue(value) + } + } 
+ + reg := make([]string, 0, len(primaryKey)) + for key, value := range primaryKey { + reg = append(reg, EscapeIdentifier(key)+" = "+escapeStringValue(value)) + } + + return strings.Join(reg[:], " AND ") +} + +// Format based on type, value returned unescaped +func (d *postgresDialect) normalizeValueType(value string, valueType reflect.Type) (string, error) { + switch valueType.Kind() { + case reflect.String: + // replace unicode null character with empty string + value = strings.ReplaceAll(value, "\u0000", "") + return escapeStringValue(value), nil + + // BYTES in Postgres must be escaped, we receive a Vec from substreams + case reflect.Slice: + return escapeStringValue(value), nil + + case reflect.Bool: + return fmt.Sprintf("'%s'", value), nil + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return value, nil + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return value, nil + + case reflect.Float32, reflect.Float64: + return value, nil + + case reflect.Struct: + if valueType == reflectTypeTime { + if integerRegex.MatchString(value) { + i, err := strconv.Atoi(value) + if err != nil { + return "", fmt.Errorf("could not convert %s to int: %w", value, err) + } + + return escapeStringValue(time.Unix(int64(i), 0).Format(time.RFC3339)), nil + } + + // It's a plain string, parse by dialect it and pass it to the database + return d.ParseDatetimeNormalization(value), nil + } + + return "", fmt.Errorf("unsupported struct type %s", valueType) + default: + // It's a column's type the schema parsing don't know how to represents as + // a Go type. In that case, we pass it unmodified to the database engine. It + // will be the responsibility of the one sending the data to correctly represent + // it in the way accepted by the database. + // + // In most cases, it going to just work. 
+ return value, nil + } +} diff --git a/db/operations.go b/db/operations.go index c0350fb..5acb356 100644 --- a/db/operations.go +++ b/db/operations.go @@ -2,15 +2,10 @@ package db import ( "fmt" - "math/big" "reflect" "regexp" - "sort" - "strconv" "strings" "time" - - "golang.org/x/exp/maps" ) type TypeGetter func(tableName string, columnName string) (reflect.Type, error) @@ -75,214 +70,9 @@ func (o *Operation) mergeData(newData map[string]string) error { return nil } -func (o *Operation) getValues() ([]any, error) { - columns := make([]string, len(o.data)) - i := 0 - for column := range o.data { - columns[i] = column - i++ - } - sort.Strings(columns) - values := make([]any, len(o.data)) - for i, v := range columns { - convertedType, err := convertToType(o.data[v], o.table.columnsByName[v].scanType) - if err != nil { - return nil, fmt.Errorf("converting value %q to type %q: %w", o.data[v], o.table.columnsByName[v].scanType, err) - } - values[i] = convertedType - } - return values, nil -} - -func (o *Operation) query(d dialect) (string, error) { - var columns, values []string - if o.opType == OperationTypeInsert || o.opType == OperationTypeUpdate { - var err error - columns, values, err = prepareColValues(d, o.table, o.data) - if err != nil { - return "", fmt.Errorf("preparing column & values: %w", err) - } - } - - switch o.opType { - case OperationTypeInsert: - return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", - o.table.identifier, - strings.Join(columns, ","), - strings.Join(values, ","), - ), nil - - case OperationTypeUpdate: - updates := make([]string, len(columns)) - for i := 0; i < len(columns); i++ { - updates[i] = fmt.Sprintf("%s=%s", columns[i], values[i]) - } - - primaryKeySelector := getPrimaryKeyWhereClause(o.primaryKey) - return fmt.Sprintf("UPDATE %s SET %s WHERE %s", - o.table.identifier, - strings.Join(updates, ", "), - primaryKeySelector, - ), nil - - case OperationTypeDelete: - primaryKeyWhereClause := getPrimaryKeyWhereClause(o.primaryKey) - 
return fmt.Sprintf("DELETE FROM %s WHERE %s", - o.table.identifier, - primaryKeyWhereClause, - ), nil - - default: - panic(fmt.Errorf("unknown operation type %q", o.opType)) - } -} - -func getPrimaryKeyWhereClause(primaryKey map[string]string) string { - // Avoid any allocation if there is a single primary key - if len(primaryKey) == 1 { - for key, value := range primaryKey { - return EscapeIdentifier(key) + " = " + escapeStringValue(value) - } - } - - reg := make([]string, 0, len(primaryKey)) - for key, value := range primaryKey { - reg = append(reg, EscapeIdentifier(key)+" = "+escapeStringValue(value)) - } - - return strings.Join(reg[:], " AND ") -} - -func prepareColValues(d dialect, table *TableInfo, colValues map[string]string) (columns []string, values []string, err error) { - if len(colValues) == 0 { - return - } - - columns = make([]string, len(colValues)) - values = make([]string, len(colValues)) - - i := 0 - for columnName, value := range colValues { - columnInfo, found := table.columnsByName[columnName] - if !found { - return nil, nil, fmt.Errorf("cannot find column %q for table %q (valid columns are %q)", columnName, table.identifier, strings.Join(maps.Keys(table.columnsByName), ", ")) - } - - normalizedValue, err := normalizeValueType(value, columnInfo.scanType, d) - if err != nil { - return nil, nil, fmt.Errorf("getting sql value from table %s for column %q raw value %q: %w", table.identifier, columnName, value, err) - } - - columns[i] = columnInfo.escapedName - values[i] = normalizedValue - - i++ - } - return -} - var integerRegex = regexp.MustCompile(`^\d+$`) var reflectTypeTime = reflect.TypeOf(time.Time{}) -func convertToType(value string, valueType reflect.Type) (any, error) { - switch valueType.Kind() { - case reflect.String: - return value, nil - case reflect.Slice: - return value, nil - case reflect.Bool: - return strconv.ParseBool(value) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return 
strconv.ParseInt(value, 10, 0) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint64: - return strconv.ParseUint(value, 10, 0) - case reflect.Uint32: - v, err := strconv.ParseUint(value, 10, 32) - return uint32(v), err - case reflect.Float32, reflect.Float64: - return strconv.ParseFloat(value, 10) - case reflect.Struct: - if valueType == reflectTypeTime { - if integerRegex.MatchString(value) { - i, err := strconv.Atoi(value) - if err != nil { - return "", fmt.Errorf("could not convert %s to int: %w", value, err) - } - - return int64(i), nil - } - - v, err := time.Parse("2006-01-02T15:04:05Z", value) - if err != nil { - return "", fmt.Errorf("could not convert %s to time: %w", value, err) - } - return v.Unix(), nil - } - return "", fmt.Errorf("unsupported struct type %s", valueType) - - case reflect.Ptr: - if valueType.String() == "*big.Int" { - newInt := new(big.Int) - newInt.SetString(value, 10) - return newInt, nil - } - return "", fmt.Errorf("unsupported pointer type %s", valueType) - default: - return value, nil - } -} - -// Format based on type, value returned unescaped -func normalizeValueType(value string, valueType reflect.Type, d dialect) (string, error) { - switch valueType.Kind() { - case reflect.String: - // replace unicode null character with empty string - value = strings.ReplaceAll(value, "\u0000", "") - return escapeStringValue(value), nil - - // BYTES in Postgres must be escaped, we receive a Vec from substreams - case reflect.Slice: - return escapeStringValue(value), nil - - case reflect.Bool: - return fmt.Sprintf("'%s'", value), nil - - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return value, nil - - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return value, nil - - case reflect.Float32, reflect.Float64: - return value, nil - - case reflect.Struct: - if valueType == reflectTypeTime { - if integerRegex.MatchString(value) { - i, err := strconv.Atoi(value) - if 
err != nil { - return "", fmt.Errorf("could not convert %s to int: %w", value, err) - } - - return escapeStringValue(time.Unix(int64(i), 0).Format(time.RFC3339)), nil - } - - // It's a plain string, parse by dialect it and pass it to the database - return d.ParseDatetimeNormalization(value), nil - } - - return "", fmt.Errorf("unsupported struct type %s", valueType) - default: - // It's a column's type the schema parsing don't know how to represents as - // a Go type. In that case, we pass it unmodified to the database engine. It - // will be the responsibility of the one sending the data to correctly represent - // it in the way accepted by the database. - // - // In most cases, it going to just work. - return value, nil - } -} - func EscapeIdentifier(valueToEscape string) string { if strings.Contains(valueToEscape, `"`) { valueToEscape = strings.ReplaceAll(valueToEscape, `"`, `""`) diff --git a/db/operations_test.go b/db/operations_test.go index bf194c1..268df7a 100644 --- a/db/operations_test.go +++ b/db/operations_test.go @@ -154,7 +154,9 @@ func Test_prepareColValues(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotColumns, gotValues, err := prepareColValues(postgresDialect{}, tt.args.table, tt.args.colValues) + dialect := postgresDialect{} + + gotColumns, gotValues, err := dialect.prepareColValues(tt.args.table, tt.args.colValues) tt.assertion(t, err) assert.Equal(t, tt.wantColumns, gotColumns) assert.Equal(t, tt.wantValues, gotValues) From b8392ea264c50dc76ed6012515e6bb36e16c4469 Mon Sep 17 00:00:00 2001 From: Gustavo Inacio Date: Fri, 25 Aug 2023 13:24:47 -0300 Subject: [PATCH 7/7] update CHANGELOG.md --- CHANGELOG.md | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5aec802..3311efc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,9 +12,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Added newer method of 
populating the database via CSV Newer commands: - - `generate_csv`: Generates CSVs for each table - - `insert_csv`: Injects generated CSV rows for - - `inject_cursor`: Injects the cursor from a file into database + - `generate_csv`: Generates CSVs for each table + - `insert_csv`: Injects generated CSV rows into the database
+ - `inject_cursor`: Injects the cursor from a file into database + + +* Added driver abstraction + +* Added Clickhouse as the second driver. + +You can connect to Clickhouse by using the following DSN: + +- Not encrypted: `clickhouse://<host>:9000/?username=<user>&password=<password>` +- Encrypted: `clickhouse://<host>:9440/?secure=true&skip_verify=true&username=<user>&password=<password>` + +If you want to send custom args to the connection, you can do so by passing them as query params. + ## v2.4.0