Skip to content

Commit

Permalink
Merge pull request #540 from mrtc0/support-multitype-field-on-mongodb
Browse files Browse the repository at this point in the history
feat: [MongoDB] Support multiple type field
  • Loading branch information
k1LoW authored Nov 23, 2023
2 parents e7ff859 + aeac717 commit 0865747
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 12 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,13 @@ dsn: mongodb://mongoadmin:secret@localhost:27017/test
dsn: mongodb://mongoadmin:secret@localhost:27017/test?sampleSize=20
```

If a field has multiple types, the `multipleFieldType` query can be used to list all the types.

``` yaml
# .tbls.yml
dsn: mongodb://mongoadmin:secret@localhost:27017/test?sampleSize=20&multipleFieldType=true
```

**JSON:**

The JSON file output by the `tbls out -t json` command can be read as a datasource.
Expand Down
11 changes: 9 additions & 2 deletions datasource/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ import (
"go.mongodb.org/mongo-driver/mongo/readpref"
)

const defaultSampleSize = 1000
const (
defaultSampleSize = 1000
defaultMultipleFieldType = false
)

// AnalyzeMongodb analyze `mongodb://`
func AnalyzeMongodb(urlstr string) (*schema.Schema, error) {
Expand Down Expand Up @@ -50,7 +53,11 @@ func AnalyzeMongodb(urlstr string) (*schema.Schema, error) {
if err != nil {
sampleSize = defaultSampleSize
}
driver, err := mongodb.New(ctx, client, dbName, sampleSize)
multipleFieldType, err := strconv.ParseBool(values.Get("multipleFieldType"))
if err != nil {
multipleFieldType = defaultMultipleFieldType
}
driver, err := mongodb.New(ctx, client, dbName, sampleSize, multipleFieldType)
if err != nil {
return s, err
}
Expand Down
47 changes: 38 additions & 9 deletions drivers/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"context"
"database/sql"
"fmt"
"slices"
"sort"
"strings"

"github.com/k1LoW/tbls/dict"
"github.com/k1LoW/tbls/schema"
Expand All @@ -14,19 +16,23 @@ import (
"go.mongodb.org/mongo-driver/mongo"
)

const columnTypeSeparator = ","

type Mongodb struct {
ctx context.Context
client *mongo.Client
dbName string
sampleSize int64
ctx context.Context
client *mongo.Client
dbName string
sampleSize int64
multipleFieldType bool
}

func New(ctx context.Context, client *mongo.Client, dbName string, sampleSize int64) (*Mongodb, error) {
func New(ctx context.Context, client *mongo.Client, dbName string, sampleSize int64, multipleFieldType bool) (*Mongodb, error) {
return &Mongodb{
ctx: ctx,
client: client,
dbName: dbName,
sampleSize: sampleSize,
ctx: ctx,
client: client,
dbName: dbName,
sampleSize: sampleSize,
multipleFieldType: multipleFieldType,
}, nil
}

Expand Down Expand Up @@ -120,6 +126,10 @@ func (d *Mongodb) listFields(collection *mongo.Collection) ([]*schema.Column, er
if !columnInColumns(column, columns) {
columns = append(columns, column)
}

if d.multipleFieldType {
columns = addColumnType(columns, key, valueType)
}
}
}
for _, col := range columns {
Expand All @@ -145,6 +155,25 @@ func columnInColumns(a *schema.Column, list []*schema.Column) bool {
return false
}

// addColumnType adds a new type to the specified column
func addColumnType(list []*schema.Column, columnName, valueType string) []*schema.Column {
columns := make([]*schema.Column, len(list))

for i, col := range list {
column := *col
if column.Name == columnName {
types := append(strings.Split(column.Type, columnTypeSeparator), valueType)
slices.Sort(types)
uniqTypes := slices.Compact(types)
column.Type = strings.Join(uniqTypes, columnTypeSeparator)
}

columns[i] = &column
}

return columns
}

func (d *Mongodb) listIndexes(collection *mongo.Collection) ([]*schema.Index, error) {
indexes := []*schema.Index{}
indexSpec, err := collection.Indexes().ListSpecifications(d.ctx)
Expand Down
57 changes: 56 additions & 1 deletion drivers/mongodb/mongodb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package mongodb
import (
"context"
"net/url"
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -38,7 +39,7 @@ func TestAnalyze(t *testing.T) {
s := &schema.Schema{
Name: "MongoDB local `docker-mongo-sample-datasets`",
}
driver, err := New(ctx, client, dbName, 10)
driver, err := New(ctx, client, dbName, 10, false)
if err != nil {
t.Errorf("%v", err)
}
Expand All @@ -55,3 +56,57 @@ func TestAnalyze(t *testing.T) {
t.Errorf("got not empty string.")
}
}

func Test_addColumnType(t *testing.T) {
columns := []*schema.Column{
{
Name: "username",
Type: "string",
},
{
Name: "age",
Type: "int",
},
}

tests := []struct {
name string
list []*schema.Column
columnName string
valueType string
want []*schema.Column
}{
{
name: "Existing types are not added",
list: columns,
columnName: "username",
valueType: "string",
want: columns,
},
{
name: "New types are added with comma separation",
list: columns,
columnName: "age",
valueType: "string",
want: []*schema.Column{
{
Name: "username",
Type: "string",
},
{
Name: "age",
Type: "int,string",
},
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got := addColumnType(test.list, test.columnName, test.valueType)
if !reflect.DeepEqual(got, test.want) {
t.Errorf("got %v\nwant %v", got, test.want)
}
})
}
}

0 comments on commit 0865747

Please sign in to comment.