diff --git a/README.md b/README.md index e5b546c6f..68d17b872 100644 --- a/README.md +++ b/README.md @@ -591,6 +591,13 @@ dsn: mongodb://mongoadmin:secret@localhost:27017/test dsn: mongodb://mongoadmin:secret@localhost:27017/test?sampleSize=20 ``` +If a field has multiple types, set the `multipleFieldType` query parameter to `true` to list all of them. + +``` yaml +# .tbls.yml +dsn: mongodb://mongoadmin:secret@localhost:27017/test?sampleSize=20&multipleFieldType=true +``` + **JSON:** The JSON file output by the `tbls out -t json` command can be read as a datasource. diff --git a/datasource/mongo.go b/datasource/mongo.go index 716ce7333..1df1af719 100644 --- a/datasource/mongo.go +++ b/datasource/mongo.go @@ -14,7 +14,10 @@ import ( "go.mongodb.org/mongo-driver/mongo/readpref" ) -const defaultSampleSize = 1000 +const ( + defaultSampleSize = 1000 + defaultMultipleFieldType = false +) // AnalyzeMongodb analyze `mongodb://` func AnalyzeMongodb(urlstr string) (*schema.Schema, error) { @@ -50,7 +53,11 @@ func AnalyzeMongodb(urlstr string) (*schema.Schema, error) { if err != nil { sampleSize = defaultSampleSize } - driver, err := mongodb.New(ctx, client, dbName, sampleSize) + multipleFieldType, err := strconv.ParseBool(values.Get("multipleFieldType")) + if err != nil { + multipleFieldType = defaultMultipleFieldType + } + driver, err := mongodb.New(ctx, client, dbName, sampleSize, multipleFieldType) if err != nil { return s, err } diff --git a/drivers/mongodb/mongodb.go b/drivers/mongodb/mongodb.go index cfededa33..d822004dd 100644 --- a/drivers/mongodb/mongodb.go +++ b/drivers/mongodb/mongodb.go @@ -4,7 +4,9 @@ import ( "context" "database/sql" "fmt" + "slices" "sort" + "strings" "github.com/k1LoW/tbls/dict" "github.com/k1LoW/tbls/schema" @@ -14,19 +16,23 @@ import ( "go.mongodb.org/mongo-driver/mongo" ) +const columnTypeSeparator = "," + type Mongodb struct { - ctx context.Context - client *mongo.Client - dbName string - sampleSize int64 + ctx context.Context + client 
*mongo.Client + dbName string + sampleSize int64 + multipleFieldType bool } -func New(ctx context.Context, client *mongo.Client, dbName string, sampleSize int64) (*Mongodb, error) { +func New(ctx context.Context, client *mongo.Client, dbName string, sampleSize int64, multipleFieldType bool) (*Mongodb, error) { return &Mongodb{ - ctx: ctx, - client: client, - dbName: dbName, - sampleSize: sampleSize, + ctx: ctx, + client: client, + dbName: dbName, + sampleSize: sampleSize, + multipleFieldType: multipleFieldType, }, nil } @@ -120,6 +126,10 @@ func (d *Mongodb) listFields(collection *mongo.Collection) ([]*schema.Column, er if !columnInColumns(column, columns) { columns = append(columns, column) } + + if d.multipleFieldType { + columns = addColumnType(columns, key, valueType) + } } } for _, col := range columns { @@ -145,6 +155,25 @@ func columnInColumns(a *schema.Column, list []*schema.Column) bool { return false } +// addColumnType returns a copy of list in which the column named columnName has valueType merged into its sorted, de-duplicated, comma-separated Type. +func addColumnType(list []*schema.Column, columnName, valueType string) []*schema.Column { + columns := make([]*schema.Column, len(list)) + + for i, col := range list { + column := *col + if column.Name == columnName { + types := append(strings.Split(column.Type, columnTypeSeparator), valueType) + slices.Sort(types) + uniqTypes := slices.Compact(types) + column.Type = strings.Join(uniqTypes, columnTypeSeparator) + } + + columns[i] = &column + } + + return columns +} + func (d *Mongodb) listIndexes(collection *mongo.Collection) ([]*schema.Index, error) { indexes := []*schema.Index{} indexSpec, err := collection.Indexes().ListSpecifications(d.ctx) diff --git a/drivers/mongodb/mongodb_test.go b/drivers/mongodb/mongodb_test.go index d1acfd5d0..8b00716a3 100644 --- a/drivers/mongodb/mongodb_test.go +++ b/drivers/mongodb/mongodb_test.go @@ -5,6 +5,7 @@ package mongodb import ( "context" "net/url" + "reflect" "strings" "testing" @@ -38,7 +39,7 @@ func TestAnalyze(t *testing.T) { s := &schema.Schema{ Name: 
"MongoDB local `docker-mongo-sample-datasets`", } - driver, err := New(ctx, client, dbName, 10) + driver, err := New(ctx, client, dbName, 10, false) if err != nil { t.Errorf("%v", err) } @@ -55,3 +56,57 @@ func TestAnalyze(t *testing.T) { t.Errorf("got not empty string.") } } + +func Test_addColumnType(t *testing.T) { + columns := []*schema.Column{ + { + Name: "username", + Type: "string", + }, + { + Name: "age", + Type: "int", + }, + } + + tests := []struct { + name string + list []*schema.Column + columnName string + valueType string + want []*schema.Column + }{ + { + name: "Existing types are not added", + list: columns, + columnName: "username", + valueType: "string", + want: columns, + }, + { + name: "New types are added with comma separation", + list: columns, + columnName: "age", + valueType: "string", + want: []*schema.Column{ + { + Name: "username", + Type: "string", + }, + { + Name: "age", + Type: "int,string", + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := addColumnType(test.list, test.columnName, test.valueType) + if !reflect.DeepEqual(got, test.want) { + t.Errorf("got %v\nwant %v", got, test.want) + } + }) + } +}