diff --git a/backend/executor.go b/backend/executor.go index 51c8095..dd578f3 100644 --- a/backend/executor.go +++ b/backend/executor.go @@ -64,7 +64,7 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row ctx.GetLogger().WithFields(logrus.Fields{ "Query": ctx.Query(), "NodeType": fmt.Sprintf("%T", n), - }).Trace("Building node:", n) + }).Traceln("Building node:", n) // TODO; find a better way to fallback to the base builder switch n.(type) { diff --git a/backend/loaddata.go b/backend/loaddata.go index d0b5078..7425077 100644 --- a/backend/loaddata.go +++ b/backend/loaddata.go @@ -14,6 +14,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/vitess/go/vt/proto/query" ) const isUnixSystem = runtime.GOOS == "linux" || @@ -244,14 +245,8 @@ func columnTypeHints(b *strings.Builder, dst sql.Table, schema sql.Schema, colNa if i > 0 { b.WriteString(", ") } - b.WriteString(catalog.QuoteIdentifierANSI(col.Name)) - b.WriteString(": ") - if dt, err := catalog.DuckdbDataType(col.Type); err != nil { + if err := columnTypeHint(b, col); err != nil { return err - } else { - b.WriteString(`'`) - b.WriteString(dt.Name()) - b.WriteString(`'`) } } b.WriteString("}") @@ -262,18 +257,12 @@ func columnTypeHints(b *strings.Builder, dst sql.Table, schema sql.Schema, colNa if i > 0 { b.WriteString(", ") } - b.WriteString(catalog.QuoteIdentifierANSI(col)) - b.WriteString(": ") idx := schema.IndexOf(col, dst.Name()) // O(n^2) but n := # of columns is usually small if idx < 0 { return sql.ErrTableColumnNotFound.New(dst.Name(), col) } - if dt, err := catalog.DuckdbDataType(schema[idx].Type); err != nil { + if err := columnTypeHint(b, schema[idx]); err != nil { return err - } else { - b.WriteString(`'`) - b.WriteString(dt.Name()) - b.WriteString(`'`) } } @@ -281,6 +270,23 @@ func columnTypeHints(b *strings.Builder, dst sql.Table, schema sql.Schema, colNa return nil } +func columnTypeHint(b *strings.Builder, col *sql.Column) error { + b.WriteString(catalog.QuoteIdentifierANSI(col.Name)) + b.WriteString(": ") + if dt, err := catalog.DuckdbDataType(col.Type); err != nil { + return err + } else { + b.WriteString(`'`) + if col.Type.Type() == query.Type_ENUM { + b.WriteString(`VARCHAR`) + } else { + b.WriteString(dt.Name()) + } + b.WriteString(`'`) + } + return nil +} + // isUnderSecureFileDir ensures that fileStr is under secureFileDir or a subdirectory of secureFileDir, errors otherwise // Copied from https://github.com/dolthub/go-mysql-server/blob/main/sql/rowexec/rel.go func isUnderSecureFileDir(secureFileDir interface{}, fileStr string) error { diff --git a/catalog/type_mapping.go b/catalog/type_mapping.go index 6a6266e..3e199bb 100644 --- a/catalog/type_mapping.go +++ b/catalog/type_mapping.go @@ -99,8 +99,13 @@ func newDateTimeType(mysqlName string, precision int) AnnotatedDuckType { } func newEnumType(typ sql.EnumType) AnnotatedDuckType { - // TODO: `ENUM` allows `,` and `'` in the values. We need to escape `'`. - typeString := `ENUM('` + strings.Join(typ.Values(), `', '`) + `')` + // For ENUM type, we need to escape single quotes in values + escapedValues := make([]string, len(typ.Values())) + for i, v := range typ.Values() { + // Replace each single quote with two single quotes to escape it + escapedValues[i] = strings.ReplaceAll(v, "'", "''") + } + typeString := `ENUM('` + strings.Join(escapedValues, `', '`) + `')` return AnnotatedDuckType{typeString, MySQLType{Name: "ENUM", Values: typ.Values(), Collation: uint16(typ.Collation())}} } diff --git a/test/bats/mysql/helper.bash b/test/bats/mysql/helper.bash index 93cf48c..8a45554 100644 --- a/test/bats/mysql/helper.bash +++ b/test/bats/mysql/helper.bash @@ -13,4 +13,12 @@ mysql_exec() { mysql_exec_stdin() { mysql -h "$MYSQL_HOST" -P "$MYSQL_PORT" -u "$MYSQL_USER" --raw --batch --skip-column-names --local-infile "$@" +} + +create_temp_file() { + local content="$1" + local tempfile + tempfile="$(mktemp)" + echo -e "$content" > "$tempfile" + echo "$tempfile" } \ No newline at end of file diff --git a/test/bats/mysql/load_data.bats b/test/bats/mysql/load_data.bats new file mode 100644 index 0000000..04a7144 --- /dev/null +++ b/test/bats/mysql/load_data.bats @@ -0,0 +1,53 @@ +#!/usr/bin/env bats +bats_require_minimum_version 1.5.0 + +load helper + +setup_file() { + mysql_exec_stdin <<-'EOF' + CREATE DATABASE load_data_test; + SET GLOBAL local_infile = 1; +EOF +} + +teardown_file() { + mysql_exec_stdin <<-'EOF' + DROP DATABASE IF EXISTS load_data_test; +EOF +} + +@test "Load a TSV file that contains an escaped JSON column" { + skip + mysql_exec_stdin <<-'EOF' + USE load_data_test; + CREATE TABLE translations (code VARCHAR(100), domain VARCHAR(16), translations JSON); + LOAD DATA LOCAL INFILE 'testdata/issue329.tsv' REPLACE INTO TABLE translations CHARACTER SET 'utf8mb4' FIELDS TERMINATED BY ' ' ESCAPED BY '\\' LINES STARTING BY '' TERMINATED BY '\n' (`code`, `domain`, `translations`); +EOF + run -0 mysql_exec 'SELECT COUNT(*) FROM load_data_test.translations' + [ "${output}" = "1" ] +} + +@test "Load a TSV file with date and enum columns" { + local tempfile + tempfile=$(create_temp_file "2025-01-06\t2025-01-06\t2025-01-06\tphprapporten") + + mysql_exec_stdin <<-EOF + USE load_data_test; + CREATE TABLE peildatum ( + datum date DEFAULT NULL, + vanaf date DEFAULT NULL, + tot date DEFAULT NULL, + doel enum('phprapporten','excelrapporten','opslagkosten') CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; + LOAD DATA LOCAL INFILE '${tempfile}' REPLACE INTO TABLE peildatum + CHARACTER SET 'utf8mb4' + FIELDS TERMINATED BY ' ' ESCAPED BY '\\\\' + LINES STARTING BY '' TERMINATED BY '\n' + (datum, vanaf, tot, doel); +EOF + + run -0 mysql_exec 'SELECT * FROM load_data_test.peildatum' + [ "${output}" = "2025-01-06 2025-01-06 2025-01-06 phprapporten" ] + + rm "$tempfile" +} \ No newline at end of file diff --git a/test/bats/mysql/load_json_column.bats b/test/bats/mysql/load_json_column.bats deleted file mode 100644 index bdeb652..0000000 --- a/test/bats/mysql/load_json_column.bats +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env bats -bats_require_minimum_version 1.5.0 - -load helper - -@test "Load a TSV file that contains an escaped JSON column" { - skip - mysql_exec_stdin <<-'EOF' - CREATE DATABASE load_json_column; - USE load_json_column; - CREATE TABLE translations (code VARCHAR(100), domain VARCHAR(16), translations JSON); - SET GLOBAL local_infile = 1; - LOAD DATA LOCAL INFILE 'testdata/issue329.tsv' REPLACE INTO TABLE `load_json_column`.`translations` CHARACTER SET 'utf8mb4' FIELDS TERMINATED BY ' ' ESCAPED BY '\\' LINES STARTING BY '' TERMINATED BY '\n' (`code`, `domain`, `translations`); -EOF - run -0 mysql_exec 'SELECT COUNT(*) FROM `load_json_column`.`translations`' - [ "${output}" = "1" ] -} \ No newline at end of file