Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: LOAD DATA into enum columns #348

Merged
merged 3 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
ctx.GetLogger().WithFields(logrus.Fields{
"Query": ctx.Query(),
"NodeType": fmt.Sprintf("%T", n),
}).Trace("Building node:", n)
}).Traceln("Building node:", n)

// TODO; find a better way to fallback to the base builder
switch n.(type) {
Expand Down
34 changes: 20 additions & 14 deletions backend/loaddata.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/vitess/go/vt/proto/query"
)

const isUnixSystem = runtime.GOOS == "linux" ||
Expand Down Expand Up @@ -244,14 +245,8 @@ func columnTypeHints(b *strings.Builder, dst sql.Table, schema sql.Schema, colNa
if i > 0 {
b.WriteString(", ")
}
b.WriteString(catalog.QuoteIdentifierANSI(col.Name))
b.WriteString(": ")
if dt, err := catalog.DuckdbDataType(col.Type); err != nil {
if err := columnTypeHint(b, col); err != nil {
return err
} else {
b.WriteString(`'`)
b.WriteString(dt.Name())
b.WriteString(`'`)
}
}
b.WriteString("}")
Expand All @@ -262,25 +257,36 @@ func columnTypeHints(b *strings.Builder, dst sql.Table, schema sql.Schema, colNa
if i > 0 {
b.WriteString(", ")
}
b.WriteString(catalog.QuoteIdentifierANSI(col))
b.WriteString(": ")
idx := schema.IndexOf(col, dst.Name()) // O(n^2) but n := # of columns is usually small
if idx < 0 {
return sql.ErrTableColumnNotFound.New(dst.Name(), col)
}
if dt, err := catalog.DuckdbDataType(schema[idx].Type); err != nil {
if err := columnTypeHint(b, schema[idx]); err != nil {
return err
} else {
b.WriteString(`'`)
b.WriteString(dt.Name())
b.WriteString(`'`)
}
}

b.WriteString("}")
return nil
}

func columnTypeHint(b *strings.Builder, col *sql.Column) error {
b.WriteString(catalog.QuoteIdentifierANSI(col.Name))
b.WriteString(": ")
if dt, err := catalog.DuckdbDataType(col.Type); err != nil {
return err
} else {
b.WriteString(`'`)
if col.Type.Type() == query.Type_ENUM {
b.WriteString(`VARCHAR`)
} else {
b.WriteString(dt.Name())
}
b.WriteString(`'`)
}
return nil
}

// isUnderSecureFileDir ensures that fileStr is under secureFileDir or a subdirectory of secureFileDir, errors otherwise
// Copied from https://github.com/dolthub/go-mysql-server/blob/main/sql/rowexec/rel.go
func isUnderSecureFileDir(secureFileDir interface{}, fileStr string) error {
Expand Down
9 changes: 7 additions & 2 deletions catalog/type_mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,13 @@ func newDateTimeType(mysqlName string, precision int) AnnotatedDuckType {
}

func newEnumType(typ sql.EnumType) AnnotatedDuckType {
// TODO: `ENUM` allows `,` and `'` in the values. We need to escape `'`.
typeString := `ENUM('` + strings.Join(typ.Values(), `', '`) + `')`
// For ENUM type, we need to escape single quotes in values
escapedValues := make([]string, len(typ.Values()))
for i, v := range typ.Values() {
// Replace each single quote with two single quotes to escape it
escapedValues[i] = strings.ReplaceAll(v, "'", "''")
}
typeString := `ENUM('` + strings.Join(escapedValues, `', '`) + `')`
return AnnotatedDuckType{typeString, MySQLType{Name: "ENUM", Values: typ.Values(), Collation: uint16(typ.Collation())}}
}

Expand Down
8 changes: 8 additions & 0 deletions test/bats/mysql/helper.bash
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,12 @@ mysql_exec() {

mysql_exec_stdin() {
mysql -h "$MYSQL_HOST" -P "$MYSQL_PORT" -u "$MYSQL_USER" --raw --batch --skip-column-names --local-infile "$@"
}

create_temp_file() {
local content="$1"
local tempfile
tempfile="$(mktemp)"
echo -e "$content" > "$tempfile"
echo "$tempfile"
}
53 changes: 53 additions & 0 deletions test/bats/mysql/load_data.bats
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#!/usr/bin/env bats
bats_require_minimum_version 1.5.0

load helper

setup_file() {
mysql_exec_stdin <<-'EOF'
CREATE DATABASE load_data_test;
SET GLOBAL local_infile = 1;
EOF
}

teardown_file() {
mysql_exec_stdin <<-'EOF'
DROP DATABASE IF EXISTS load_data_test;
EOF
}

@test "Load a TSV file that contains an escaped JSON column" {
skip
mysql_exec_stdin <<-'EOF'
USE load_data_test;
CREATE TABLE translations (code VARCHAR(100), domain VARCHAR(16), translations JSON);
LOAD DATA LOCAL INFILE 'testdata/issue329.tsv' REPLACE INTO TABLE translations CHARACTER SET 'utf8mb4' FIELDS TERMINATED BY ' ' ESCAPED BY '\\' LINES STARTING BY '' TERMINATED BY '\n' (`code`, `domain`, `translations`);
EOF
run -0 mysql_exec 'SELECT COUNT(*) FROM load_data_test.translations'
[ "${output}" = "1" ]
}

@test "Load a TSV file with date and enum columns" {
local tempfile
tempfile=$(create_temp_file "2025-01-06\t2025-01-06\t2025-01-06\tphprapporten")

mysql_exec_stdin <<-EOF
USE load_data_test;
CREATE TABLE peildatum (
datum date DEFAULT NULL,
vanaf date DEFAULT NULL,
tot date DEFAULT NULL,
doel enum('phprapporten','excelrapporten','opslagkosten') CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;
LOAD DATA LOCAL INFILE '${tempfile}' REPLACE INTO TABLE peildatum
CHARACTER SET 'utf8mb4'
FIELDS TERMINATED BY ' ' ESCAPED BY '\\\\'
LINES STARTING BY '' TERMINATED BY '\n'
(datum, vanaf, tot, doel);
EOF

run -0 mysql_exec 'SELECT * FROM load_data_test.peildatum'
[ "${output}" = "2025-01-06 2025-01-06 2025-01-06 phprapporten" ]

rm "$tempfile"
}
17 changes: 0 additions & 17 deletions test/bats/mysql/load_json_column.bats

This file was deleted.

Loading