Skip to content

Commit

Permalink
feat: initial AUTO_INCREMENT support for MySQL (#292)
Browse files Browse the repository at this point in the history
* chore: upgrade go-mysql-server
* feat: initial AUTO_INCREMENT support for MySQL
* Remove supported query from disallow list
  • Loading branch information
fanyang01 authored Dec 17, 2024
1 parent 1e155bc commit 708d026
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 40 deletions.
2 changes: 0 additions & 2 deletions catalog/comment.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ type Comment[T any] struct {
Meta T `json:"meta,omitempty"` // extra information, e.g. the original MySQL column type, etc.
}

const ManagedCommentPrefix = "base64:"

func DecodeComment[T any](encodedOrRawText string) *Comment[T] {
if !strings.HasPrefix(encodedOrRawText, ManagedCommentPrefix) {
return NewComment[T](encodedOrRawText)
Expand Down
8 changes: 8 additions & 0 deletions catalog/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package catalog

const (
// ManagedCommentPrefix is the prefix for comments that are managed by the catalog.
ManagedCommentPrefix = "base64:"
// SequenceNamePrefix is the prefix for sequence names that are managed by the catalog.
SequenceNamePrefix = "__sys_table_seq_"
)
52 changes: 48 additions & 4 deletions catalog/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"github.com/apecloud/myduckserver/configuration"
"github.com/apecloud/myduckserver/mycontext"
"github.com/dolthub/go-mysql-server/sql"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)

type Database struct {
Expand Down Expand Up @@ -120,6 +122,8 @@ func (d *Database) createAllTable(ctx *sql.Context, name string, schema sql.Prim
fullTableName = FullTableName(d.catalog, d.name, name)
}

var sequenceName, fullSequenceName string

for _, col := range schema.Schema {
typ, err := DuckdbDataType(col.Type)
if err != nil {
Expand All @@ -133,11 +137,30 @@ func (d *Database) createAllTable(ctx *sql.Context, name string, schema sql.Prim
}

if col.Default != nil {
columnDefault, err := typ.mysql.withDefault(col.Default.String())
typ.mysql.Default = col.Default.String()
defaultExpr, err := parseDefaultValue(typ.mysql.Default)
if err != nil {
return err
}
colDef += " DEFAULT " + defaultExpr
} else if col.AutoIncrement {
typ.mysql.AutoIncrement = true

// Generate a random sequence name.
// TODO(fan): Drop the sequence when the table is dropped or the column is removed.
uuid, err := uuid.NewRandom()
if err != nil {
return err
}
colDef += " DEFAULT " + columnDefault
sequenceName = SequenceNamePrefix + uuid.String()
if temporary {
fullSequenceName = `temp.main."` + sequenceName + `"`
} else {
fullSequenceName = InternalSchemas.SYS.Schema + `."` + sequenceName + `"`
}

defaultExpr := `nextval('` + fullSequenceName + `')`
colDef += " DEFAULT " + defaultExpr
}

columns = append(columns, colDef)
Expand All @@ -158,6 +181,20 @@ func (d *Database) createAllTable(ctx *sql.Context, name string, schema sql.Prim
}

var b strings.Builder
b.Grow(256)

if sequenceName != "" {
b.WriteString(`CREATE `)
if temporary {
b.WriteString(`TEMP SEQUENCE "`)
b.WriteString(sequenceName)
b.WriteString(`"`)
} else {
b.WriteString(`SEQUENCE `)
b.WriteString(fullSequenceName)
}
b.WriteString(`;`)
}

if temporary {
b.WriteString(fmt.Sprintf(`CREATE TEMP TABLE %s (%s`, name, strings.Join(columns, ", ")))
Expand All @@ -180,10 +217,11 @@ func (d *Database) createAllTable(ctx *sql.Context, name string, schema sql.Prim
b.WriteString(")")

// Add comment to the table
info := ExtraTableInfo{schema.PkOrdinals, withoutIndex, fullSequenceName}
b.WriteString(fmt.Sprintf(
"; COMMENT ON TABLE %s IS '%s'",
fullTableName,
NewCommentWithMeta(comment, ExtraTableInfo{schema.PkOrdinals, withoutIndex}).Encode(),
NewCommentWithMeta(comment, info).Encode(),
))

// Add column comments
Expand All @@ -192,7 +230,13 @@ func (d *Database) createAllTable(ctx *sql.Context, name string, schema sql.Prim
b.WriteString(s)
}

_, err := adapter.Exec(ctx, b.String())
ddl := b.String()

if logger := ctx.GetLogger(); logger.Logger.GetLevel() >= logrus.DebugLevel {
logger.WithField("DuckSQL", ddl).Debug("Executing DDL")
}

_, err := adapter.Exec(ctx, ddl)
if err != nil {
if IsDuckDBTableAlreadyExistsError(err) {
return sql.ErrTableAlreadyExists.New(name)
Expand Down
6 changes: 5 additions & 1 deletion catalog/internal_schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@ type InternalSchema struct {
}

var InternalSchemas = struct {
SYS InternalSchema
MySQL InternalSchema
}{
SYS: InternalSchema{
Schema: "__sys__",
},
MySQL: InternalSchema{
Schema: "mysql",
},
}

var internalSchemas = []InternalSchema{
InternalSchemas.MySQL,
InternalSchemas.MySQL, InternalSchemas.SYS,
}
99 changes: 95 additions & 4 deletions catalog/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package catalog
import (
stdsql "database/sql"
"fmt"
"strconv"
"strings"
"sync"

Expand All @@ -25,6 +26,7 @@ type Table struct {
type ExtraTableInfo struct {
PkOrdinals []int
Replicated bool
Sequence string
}

type ColumnInfo struct {
Expand All @@ -51,6 +53,7 @@ var _ sql.DeletableTable = (*Table)(nil)
var _ sql.TruncateableTable = (*Table)(nil)
var _ sql.ReplaceableTable = (*Table)(nil)
var _ sql.CommentedTable = (*Table)(nil)
var _ sql.AutoIncrementTable = (*Table)(nil)

func NewTable(name string, db *Database) *Table {
return &Table{
Expand Down Expand Up @@ -132,6 +135,7 @@ func getPKSchema(ctx *sql.Context, catalogName, dbName, tableName string) sql.Pr
Source: tableName,
DatabaseSource: dbName,
Default: defaultValue,
AutoIncrement: decodedComment.Meta.AutoIncrement,
Comment: decodedComment.Text,
}

Expand Down Expand Up @@ -201,11 +205,12 @@ func (t *Table) AddColumn(ctx *sql.Context, column *sql.Column, order *sql.Colum
}

if column.Default != nil {
columnDefault, err := typ.mysql.withDefault(column.Default.String())
typ.mysql.Default = column.Default.String()
defaultExpr, err := parseDefaultValue(typ.mysql.Default)
if err != nil {
return err
}
sql += fmt.Sprintf(" DEFAULT %s", columnDefault)
sql += " DEFAULT " + defaultExpr
}

// add comment
Expand Down Expand Up @@ -257,11 +262,12 @@ func (t *Table) ModifyColumn(ctx *sql.Context, columnName string, column *sql.Co
}

if column.Default != nil {
columnDefault, err := typ.mysql.withDefault(column.Default.String())
typ.mysql.Default = column.Default.String()
defaultExpr, err := parseDefaultValue(typ.mysql.Default)
if err != nil {
return err
}
sqls = append(sqls, fmt.Sprintf(`%s SET DEFAULT %s`, baseSQL, columnDefault))
sqls = append(sqls, fmt.Sprintf(`%s SET DEFAULT %s`, baseSQL, defaultExpr))
} else {
sqls = append(sqls, fmt.Sprintf(`%s DROP DEFAULT`, baseSQL))
}
Expand Down Expand Up @@ -578,3 +584,88 @@ func queryColumns(ctx *sql.Context, catalogName, schemaName, tableName string) (
func (t *IndexedTable) LookupPartitions(ctx *sql.Context, lookup sql.IndexLookup) (sql.PartitionIter, error) {
return nil, fmt.Errorf("unimplemented(LookupPartitions) (table: %s, query: %s)", t.name, ctx.Query())
}

// PeekNextAutoIncrementValue implements sql.AutoIncrementTable.
func (t *Table) PeekNextAutoIncrementValue(ctx *sql.Context) (uint64, error) {
if t.comment.Meta.Sequence == "" {
return 0, sql.ErrNoAutoIncrementCol
}

// For PeekNextAutoIncrementValue, we want to see what the next value would be
// without actually incrementing. We can do this by getting currval + 1.
var val uint64
err := adapter.QueryRowCatalog(ctx, `SELECT currval('`+t.comment.Meta.Sequence+`') + 1`).Scan(&val)
if err != nil {
return 0, ErrDuckDB.New(err)
}

return val, nil
}

// GetNextAutoIncrementValue implements sql.AutoIncrementTable.
func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error) {
if t.comment.Meta.Sequence == "" {
return 0, sql.ErrNoAutoIncrementCol
}

// If insertVal is provided and greater than current sequence value, update sequence
if insertVal != nil {
var start uint64
switch v := insertVal.(type) {
case uint64:
start = v
case int64:
if v > 0 {
start = uint64(v)
}
}
if start > 0 {
err := t.setAutoIncrementValue(ctx, start)
if err != nil {
return 0, err
}
return start, nil
}
}

// Get next value from sequence
var val uint64
err := adapter.QueryRowCatalog(ctx, `SELECT nextval('`+t.comment.Meta.Sequence+`')`).Scan(&val)
if err != nil {
return 0, ErrDuckDB.New(err)
}

return val, nil
}

// AutoIncrementSetter implements sql.AutoIncrementTable.
func (t *Table) AutoIncrementSetter(ctx *sql.Context) sql.AutoIncrementSetter {
if t.comment.Meta.Sequence == "" {
return nil
}
return &autoIncrementSetter{t: t}
}

// setAutoIncrementValue is a helper function to update the sequence value
func (t *Table) setAutoIncrementValue(ctx *sql.Context, value uint64) error {
_, err := adapter.ExecCatalog(ctx, `CREATE OR REPLACE SEQUENCE `+t.comment.Meta.Sequence+` START WITH `+strconv.FormatUint(value, 10))
return err
}

// autoIncrementSetter implements the AutoIncrementSetter interface
type autoIncrementSetter struct {
t *Table
}

func (s *autoIncrementSetter) SetAutoIncrementValue(ctx *sql.Context, value uint64) error {
return s.t.setAutoIncrementValue(ctx, value)
}

func (s *autoIncrementSetter) Close(ctx *sql.Context) error {
return nil
}

func (s *autoIncrementSetter) AcquireAutoIncrementLock(ctx *sql.Context) (func(), error) {
// DuckDB handles sequence synchronization internally
return func() {}, nil
}
24 changes: 12 additions & 12 deletions catalog/type_mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,16 @@ func (t AnnotatedDuckType) MySQL() MySQLType {
}

type MySQLType struct {
Name string
Length uint32 `json:",omitempty"`
Precision uint8 `json:",omitempty"`
Scale uint8 `json:",omitempty"`
Unsigned bool `json:",omitempty"`
Display uint8 `json:",omitempty"` // Display width for integer types
Collation uint16 `json:",omitempty"` // For string types
Values []string `json:",omitempty"` // For ENUM and SET
Default string `json:",omitempty"` // Default value of column
Name string
Length uint32 `json:",omitempty"`
Precision uint8 `json:",omitempty"`
Scale uint8 `json:",omitempty"`
Unsigned bool `json:",omitempty"`
Display uint8 `json:",omitempty"` // Display width for integer types
Collation uint16 `json:",omitempty"` // For string types
Values []string `json:",omitempty"` // For ENUM and SET
Default string `json:",omitempty"` // Default value of column
AutoIncrement bool `json:",omitempty"` // Auto increment flag
}

func newCommonType(name string) AnnotatedDuckType {
Expand Down Expand Up @@ -316,9 +317,8 @@ func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericSc
}
}

func (typ *MySQLType) withDefault(defaultValue string) (string, error) {
typ.Default = defaultValue
parsed, err := sqlparser.Parse(fmt.Sprintf("SELECT %s", defaultValue))
func parseDefaultValue(defaultValue string) (string, error) {
parsed, err := sqlparser.Parse("SELECT " + defaultValue)
if err != nil {
return "", err
}
Expand Down
14 changes: 7 additions & 7 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@ require (
github.com/aws/aws-sdk-go-v2/credentials v1.17.30
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.15
github.com/aws/aws-sdk-go-v2/service/s3 v1.60.1
github.com/aws/smithy-go v1.20.4
github.com/cockroachdb/apd/v3 v3.2.1
github.com/cockroachdb/cockroachdb-parser v0.23.2
github.com/cockroachdb/errors v1.9.0
github.com/dolthub/doltgresql v0.13.0
github.com/dolthub/go-mysql-server v0.18.2-0.20241127000145-a1809677932e
github.com/dolthub/vitess v0.0.0-20241126223332-cd8f828f26ac
github.com/dolthub/go-mysql-server v0.18.2-0.20241215013221-68ab2c34608f
github.com/dolthub/vitess v0.0.0-20241211024425-b00987f7ba54
github.com/go-sql-driver/mysql v1.8.1
github.com/google/uuid v1.6.0
github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9
Expand All @@ -37,8 +36,8 @@ require (
)

replace (
github.com/dolthub/go-mysql-server v0.18.2-0.20241127000145-a1809677932e => github.com/apecloud/go-mysql-server v0.0.0-20241127073935-94c04f2f750d
github.com/dolthub/vitess v0.0.0-20241126223332-cd8f828f26ac => github.com/apecloud/dolt-vitess v0.0.0-20241127063501-5c7c985f0e57
github.com/dolthub/go-mysql-server v0.18.2-0.20241215013221-68ab2c34608f => github.com/apecloud/go-mysql-server v0.0.0-20241217030038-1ec40b6e7e7f
github.com/dolthub/vitess v0.0.0-20241211024425-b00987f7ba54 => github.com/apecloud/dolt-vitess v0.0.0-20241217030333-e641a5d88d61
github.com/marcboeker/go-duckdb v1.8.3 => github.com/apecloud/go-duckdb v0.0.0-20241127093618-047c1a233928
)

Expand All @@ -59,6 +58,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 // indirect
github.com/aws/smithy-go v1.20.4 // indirect
github.com/bazelbuild/rules_go v0.46.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/biogo/store v0.0.0-20201120204734-aad293a2328f // indirect
Expand All @@ -69,7 +69,7 @@ require (
github.com/dave/dst v0.27.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 // indirect
github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662 // indirect
github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90 // indirect
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/getsentry/sentry-go v0.12.0 // indirect
Expand Down Expand Up @@ -111,7 +111,7 @@ require (
github.com/rs/xid v1.5.0 // indirect
github.com/sasha-s/go-deadlock v0.3.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/tetratelabs/wazero v1.1.0 // indirect
github.com/tetratelabs/wazero v1.8.2 // indirect
github.com/twpayne/go-geom v1.4.1 // indirect
github.com/twpayne/go-kml v1.5.2 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
Expand Down
Loading

0 comments on commit 708d026

Please sign in to comment.