Skip to content

Commit

Permalink
feats: add prepare stmt option, close #74
Browse files Browse the repository at this point in the history
  • Loading branch information
wentaojin committed Oct 25, 2024
1 parent efad904 commit 455f3a8
Show file tree
Hide file tree
Showing 22 changed files with 955 additions and 420 deletions.
2 changes: 2 additions & 0 deletions component/cli/migrate/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"bytes"
"context"
"fmt"

"github.com/fatih/color"
"github.com/wentaojin/dbms/service"

Expand Down Expand Up @@ -48,6 +49,7 @@ type SqlMigrateParam struct {
EnableCheckpoint bool `toml:"enable-checkpoint" json:"enableCheckpoint"`
EnableConsistentRead bool `toml:"enable-consistent-read" json:"enableConsistentRead"`
EnableSafeMode bool `toml:"enable-safe-mode" json:"enableSafeMode"`
EnablePrepareStmt bool `toml:"enable-prepare-stmt" json:"enablePrepareStmt"`
}

func (s *SqlConfig) String() string {
Expand Down
3 changes: 2 additions & 1 deletion database/data_migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ type IDatabaseDataMigrate interface {
GetDatabaseTableRows(schemaName, tableName string) (uint64, error)
GetDatabaseTableSize(schemaName, tableName string) (float64, error)
GetDatabaseTableChunkTask(taskName, schemaName, tableName string, chunkSize uint64, callTimeout uint64, batchSize int, dataChan chan []map[string]string) error
GetDatabaseTableChunkData(querySQL string, queryArgs []interface{}, batchSize, callTimeout int, dbCharsetS, dbCharsetT, columnDetailO string, dataChan chan []interface{}) error
GetDatabaseTableStmtData(querySQL string, queryArgs []interface{}, batchSize, callTimeout int, dbCharsetS, dbCharsetT, columnDetailO string, dataChan chan []interface{}) error
GetDatabaseTableNonStmtData(taskFlow, querySQL string, queryArgs []interface{}, batchSize, callTimeout int, dbCharsetS, dbCharsetT, columnDetailO string, dataChan chan []interface{}) error
GetDatabaseTableCsvData(querySQL string, queryArgs []interface{}, callTimeout int, taskFlow, dbCharsetS, dbCharsetT, columnDetailO string, escapeBackslash bool, nullValue, separator, delimiter string, dataChan chan []string) error
IDatabaseDataMigrateSnapshot
}
Expand Down
7 changes: 6 additions & 1 deletion database/mysql/data_migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,12 @@ func (d *Database) GetDatabaseTableChunkTask(taskName, schemaName, tableName str
panic("implement me")
}

func (d *Database) GetDatabaseTableChunkData(querySQL string, queryArgs []interface{}, batchSize, callTimeout int, dbCharsetS, dbCharsetT, columnDetailS string, dataChan chan []interface{}) error {
func (d *Database) GetDatabaseTableStmtData(querySQL string, queryArgs []interface{}, batchSize, callTimeout int, dbCharsetS, dbCharsetT, columnDetailS string, dataChan chan []interface{}) error {
//TODO implement me
panic("implement me")
}

func (d *Database) GetDatabaseTableNonStmtData(taskFlow, querySQL string, queryArgs []interface{}, batchSize, callTimeout int, dbCharsetS, dbCharsetT, columnDetailS string, dataChan chan []interface{}) error {
//TODO implement me
panic("implement me")
}
Expand Down
203 changes: 202 additions & 1 deletion database/oracle/data_migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ END;`, taskName)
return nil
}

func (d *Database) GetDatabaseTableChunkData(querySQL string, queryArgs []interface{}, batchSize, callTimeout int, dbCharsetS, dbCharsetT, columnDetailO string, dataChan chan []interface{}) error {
func (d *Database) GetDatabaseTableStmtData(querySQL string, queryArgs []interface{}, batchSize, callTimeout int, dbCharsetS, dbCharsetT, columnDetailO string, dataChan chan []interface{}) error {
var (
databaseTypes []string
err error
Expand Down Expand Up @@ -426,6 +426,207 @@ func (d *Database) GetDatabaseTableChunkData(querySQL string, queryArgs []interf
return nil
}

func (d *Database) GetDatabaseTableNonStmtData(taskFlow, querySQL string, queryArgs []interface{}, batchSize, callTimeout int, dbCharsetS, dbCharsetT, columnDetailO string, dataChan chan []interface{}) error {
var (
databaseTypes []string
err error
)
columnNameOrders := stringutil.StringSplit(columnDetailO, constant.StringSeparatorComma)
columnNameOrdersCounts := len(columnNameOrders)
rowData := make([]string, columnNameOrdersCounts)

batchRowsData := make([]string, 0, batchSize)

batchRowsDataChanTemp := make([]interface{}, 0, 1)

columnNameOrderIndexMap := make(map[string]int, columnNameOrdersCounts)

for i, c := range columnNameOrders {
columnNameOrderIndexMap[c] = i
}

deadline := time.Now().Add(time.Duration(callTimeout) * time.Second)

ctx, cancel := context.WithDeadline(d.Ctx, deadline)
defer cancel()

rows, err := d.QueryContext(ctx, querySQL, queryArgs...)
if err != nil {
return err
}
defer rows.Close()

colTypes, err := rows.ColumnTypes()
if err != nil {
return err
}

for _, ct := range colTypes {
databaseTypes = append(databaseTypes, ct.DatabaseTypeName())
}

// data scan
values := make([]interface{}, columnNameOrdersCounts)
valuePtrs := make([]interface{}, columnNameOrdersCounts)
for i, _ := range columnNameOrders {
valuePtrs[i] = &values[i]
}

for rows.Next() {
err = rows.Scan(valuePtrs...)
if err != nil {
return err
}

for i, colName := range columnNameOrders {
valRes := values[i]
if stringutil.IsValueNil(valRes) {
rowData[columnNameOrderIndexMap[colName]] = `NULL`
} else {
value := reflect.ValueOf(valRes).Interface()
switch val := value.(type) {
case godror.Number:
rfs, err := decimal.NewFromString(val.String())
if err != nil {
return fmt.Errorf("column [%s] datatype [%s] value [%v] NewFromString strconv failed, %v", colName, databaseTypes[i], val, err)
}
rowData[columnNameOrderIndexMap[colName]] = rfs.String()
case *godror.Lob:
lobD, err := val.Hijack()
if err != nil {
return fmt.Errorf("column [%s] datatype [%s] value [%v] hijack failed, %v", colName, databaseTypes[i], val, err)
}
if strings.EqualFold(databaseTypes[i], "BFILE") {
dir, file, err := lobD.GetFileName()
if err != nil {
return fmt.Errorf("column [%s] datatype [%s] value [%v] hijack getfilename failed, %v", colName, databaseTypes[i], val, err)
}
dirPath, err := d.GetDatabaseDirectoryName(dir)
if err != nil {
return fmt.Errorf("column [%s] datatype [%s] value [%v] hijack get directory name failed, %v", colName, databaseTypes[i], val, err)
}
rowData[columnNameOrderIndexMap[colName]] = fmt.Sprintf("'%v'", filepath.Join(dirPath, file))
} else {
// get actual data
lobSize, err := lobD.Size()
if err != nil {
return fmt.Errorf("column [%s] datatype [%s] value [%v] hijack size failed, %v", colName, databaseTypes[i], val, err)
}

buf := make([]byte, lobSize)

var (
res strings.Builder
offset int64
)
for {
count, err := lobD.ReadAt(buf, offset)
if err != nil {
return fmt.Errorf("column [%s] datatype [%s] value [%v] hijack readAt failed, %v", colName, databaseTypes[i], val, err)
}
if int64(count) > lobSize/int64(4) {
count = int(lobSize / 4)
}
offset += int64(count)
res.Write(buf[:count])
if count == 0 {
break
}
}
switch {
case strings.EqualFold(taskFlow, constant.TaskFlowOracleToTiDB) || strings.EqualFold(taskFlow, constant.TaskFlowOracleToMySQL):
convertTargetRaw, err := stringutil.CharsetConvert([]byte(stringutil.SpecialLettersMySQLCompatibleDatabase([]byte(res.String()))), constant.CharsetUTF8MB4, dbCharsetT)
if err != nil {
return fmt.Errorf("column [%s] charset convert failed, %v", colName, err)
}
rowData[columnNameOrderIndexMap[colName]] = fmt.Sprintf("'%v'", stringutil.BytesToString(convertTargetRaw))
default:
return fmt.Errorf("the task_flow [%s] isn't support, please contact author or reselect, %v", taskFlow, err)
}
}
err = lobD.Close()
if err != nil {
return fmt.Errorf("column [%s] datatype [%s] value [%v] hijack close failed, %v", colName, databaseTypes[i], val, err)
}
case string:
if strings.EqualFold(val, "") {
rowData[columnNameOrderIndexMap[colName]] = `NULL`
} else {
convertUtf8Raw, err := stringutil.CharsetConvert([]byte(val), dbCharsetS, constant.CharsetUTF8MB4)
if err != nil {
return fmt.Errorf("column [%s] datatype [%s] value [%v] charset convert failed, %v", colName, databaseTypes[i], val, err)
}

switch {
case strings.EqualFold(taskFlow, constant.TaskFlowOracleToTiDB) || strings.EqualFold(taskFlow, constant.TaskFlowOracleToMySQL):
convertTargetRaw, err := stringutil.CharsetConvert([]byte(stringutil.SpecialLettersMySQLCompatibleDatabase(convertUtf8Raw)), constant.CharsetUTF8MB4, dbCharsetT)
if err != nil {
return fmt.Errorf("column [%s] charset convert failed, %v", colName, err)
}
rowData[columnNameOrderIndexMap[colName]] = fmt.Sprintf("'%v'", stringutil.BytesToString(convertTargetRaw))
default:
return fmt.Errorf("the task_flow [%s] isn't support, please contact author or reselect, %v", taskFlow, err)
}
}
case []uint8:
// binary data -> raw、long raw、blob
switch {
case strings.EqualFold(taskFlow, constant.TaskFlowOracleToTiDB) || strings.EqualFold(taskFlow, constant.TaskFlowOracleToMySQL):
convertTargetRaw, err := stringutil.CharsetConvert([]byte(stringutil.SpecialLettersMySQLCompatibleDatabase(val)), constant.CharsetUTF8MB4, dbCharsetT)
if err != nil {
return fmt.Errorf("column [%s] charset convert failed, %v", colName, err)
}
rowData[columnNameOrderIndexMap[colName]] = fmt.Sprintf("'%v'", stringutil.BytesToString(convertTargetRaw))
default:
return fmt.Errorf("the task_flow [%s] isn't support, please contact author or reselect, %v", taskFlow, err)
}
case int64:
rowData[columnNameOrderIndexMap[colName]] = decimal.NewFromInt(val).String()
case uint64:
rowData[columnNameOrderIndexMap[colName]] = strconv.FormatUint(val, 10)
case float32:
rowData[columnNameOrderIndexMap[colName]] = decimal.NewFromFloat32(val).String()
case float64:
rowData[columnNameOrderIndexMap[colName]] = decimal.NewFromFloat(val).String()
case int32:
rowData[columnNameOrderIndexMap[colName]] = decimal.NewFromInt32(val).String()
default:
return fmt.Errorf("the task_flow [%s] query [%s] column [%s] unsupported type: %T", taskFlow, querySQL, colName, value)
}
}
}

// temporary array
batchRowsData = append(batchRowsData, stringutil.StringJoin(rowData, constant.StringSeparatorComma))

// clear
rowData = make([]string, columnNameOrdersCounts)

// batch
if len(batchRowsData) == batchSize {
batchRowsDataChanTemp = append(batchRowsDataChanTemp, rowData)

dataChan <- batchRowsDataChanTemp

// clear
batchRowsDataChanTemp = make([]interface{}, 0, 1)
batchRowsData = make([]string, 0, batchSize)
}
}

if err = rows.Err(); err != nil {
return err
}

// non-batch batch
if len(batchRowsData) > 0 {
batchRowsDataChanTemp = append(batchRowsDataChanTemp, rowData)
dataChan <- batchRowsDataChanTemp
}

return nil
}

func (d *Database) GetDatabaseTableCsvData(querySQL string, queryArgs []interface{}, callTimeout int, taskFlow, dbCharsetS, dbCharsetT, columnDetailO string, escapeBackslash bool, nullValue, separator, delimiter string, dataChan chan []string) error {
var (
databaseTypes []string
Expand Down
Loading

0 comments on commit 455f3a8

Please sign in to comment.