Skip to content

Commit

Permalink
Spike out duckdb support
Browse files Browse the repository at this point in the history
  • Loading branch information
gwenwindflower committed Apr 14, 2024
1 parent 6032a44 commit 1501f11
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 43 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@
_Disclaimer: This project is not affiliated with dbt Labs in any way. It is a personal project and is not officially supported by dbt Labs. I work at dbt Labs, but I develop this project in my own time._

## Supported warehouses

- [x] BigQuery
- [x] Snowflake
- [ ] Redshift
- [ ] Databricks
- [ ] Postgres
- [ ] DuckDB

## Installation

For the time being this project ideally requires `go`. When I've gotten test coverage up to a reasonable level and covered another dbt adapter or two, I'll set up a Homebrew tap. In the meantime, you can install it with the following command:
Expand Down
37 changes: 31 additions & 6 deletions forms.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type FormResponse struct {
Schema string
Project string
Dataset string
Path string
BuildDir string
GenerateDescriptions bool
GroqKeyEnvVar string
Expand Down Expand Up @@ -93,6 +94,7 @@ You'll need:
Options(
huh.NewOption("Snowflake", "snowflake"),
huh.NewOption("BigQuery", "bigquery"),
huh.NewOption("DuckDB", "duckdb"),
).
Value(&formResponse.Warehouse),
),
Expand Down Expand Up @@ -124,6 +126,16 @@ You'll need:
Value(&formResponse.Dataset).Placeholder("mirkwood"),
),
)
duckdb_form := huh.NewForm(
huh.NewGroup(
huh.NewInput().Title("What is the path to your DuckDB database?").
Value(&formResponse.Path).Placeholder("/path/to/duckdb.db"),
huh.NewInput().Title("What is the DuckDB database you want to generate?").
Value(&formResponse.Database).Placeholder("duckdb"),
huh.NewInput().Title("What is the schema you want to generate?").
Value(&formResponse.Schema).Placeholder("raw"),
),
)
llm_form := huh.NewForm(
huh.NewGroup(
huh.NewInput().
Expand Down Expand Up @@ -159,6 +171,7 @@ tbd will overwrite any existing files of the same name.`),
warehouse_form.WithTheme(huh.ThemeCatppuccin())
snowflake_form.WithTheme(huh.ThemeCatppuccin())
bigquery_form.WithTheme(huh.ThemeCatppuccin())
duckdb_form.WithTheme(huh.ThemeCatppuccin())
llm_form.WithTheme(huh.ThemeCatppuccin())
dir_form.WithTheme(huh.ThemeCatppuccin())
confirm_form.WithTheme(huh.ThemeCatppuccin())
Expand All @@ -168,24 +181,36 @@ tbd will overwrite any existing files of the same name.`),
}
if formResponse.UseDbtProfile {
err = dbt_form.Run()
if err != nil {
log.Fatalf("Error running dbt form %v\n", err)
}
} else {
err = warehouse_form.Run()
if err != nil {
log.Fatalf("Error running warehouse form %v\n", err)
}
switch formResponse.Warehouse {
case "snowflake":
err = snowflake_form.Run()
if err != nil {
log.Fatalf("Error running snowflake form %v\n", err)
}
case "bigquery":
err = bigquery_form.Run()
if err != nil {
log.Fatalf("Error running bigquery form %v\n", err)
{
err = bigquery_form.Run()
if err != nil {
log.Fatalf("Error running bigquery form %v\n", err)
}
}
case "duckdb":
{
err = duckdb_form.Run()
if err != nil {
log.Fatalf("Error running duckdb form %v\n", err)
}
}
}
}
if err != nil {
log.Fatalf("Error running connection details form %v\n", err)
}
if formResponse.GenerateDescriptions {
err = llm_form.Run()
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ require (
github.com/klauspost/compress v1.17.7 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/marcboeker/go-duckdb v1.6.3 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/mtibben/percent v0.2.1 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,17 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/marcboeker/go-duckdb v1.6.3 h1:5qRxB3BosFXRjfQWNP0OOqEQFXllo6o7fHGrNA7NSuM=
github.com/marcboeker/go-duckdb v1.6.3/go.mod h1:WtWeqqhZoTke/Nbd7V9lnBx7I2/A/q0SAq/urGzPCMs=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/mtibben/percent v0.2.1 h1:5gssi8Nqo8QU/r2pynCm+hBQHpkB/uNK7BJCFogWdzs=
github.com/mtibben/percent v0.2.1/go.mod h1:KG9uO+SZkUp+VkRHsCdYQV3XSZrrSpR3O9ibNBTZrns=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
Expand Down
16 changes: 16 additions & 0 deletions set_connection_details.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ func SetConnectionDetails(formResponse FormResponse) shared.ConnectionDetails {
Dataset: profile.Outputs[formResponse.DbtProfileOutput].Dataset,
}
}
case "duckdb":
{
connectionDetails = shared.ConnectionDetails{
ConnType: profile.Outputs[formResponse.DbtProfileOutput].ConnType,
Database: profile.Outputs[formResponse.DbtProfileOutput].Database,
Schema: formResponse.Schema,
}
}
default:
{
log.Fatalf("Unsupported connection type %v\n", profile.Outputs[formResponse.DbtProfileOutput].ConnType)
Expand All @@ -57,6 +65,14 @@ func SetConnectionDetails(formResponse FormResponse) shared.ConnectionDetails {
Dataset: formResponse.Dataset,
}
}
case "duckdb":
{
connectionDetails = shared.ConnectionDetails{
ConnType: formResponse.Warehouse,
Database: formResponse.Database,
Schema: formResponse.Schema,
}
}
default:
{
log.Fatalf("Unsupported connection type %v\n", formResponse.Warehouse)
Expand Down
1 change: 1 addition & 0 deletions shared/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ type ConnectionDetails struct {
Schema string
Project string
Dataset string
Path string
}
11 changes: 11 additions & 0 deletions sourcerer/connect_to_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"cloud.google.com/go/bigquery"
_ "github.com/marcboeker/go-duckdb"
_ "github.com/snowflakedb/gosnowflake"
)

Expand Down Expand Up @@ -38,3 +39,13 @@ func (bqc *BqConn) ConnectToDB(ctx context.Context) (err error) {
}
return err
}

// ConnectToDB opens the DuckDB database located at dc.Path, stores the handle
// in dc.Db and verifies that a connection can actually be established.
//
// The connection check is bounded by a one minute timeout derived from ctx.
// dc.Cancel is always set to that timeout's cancel function, so callers may
// safely invoke it after ConnectToDB returns; it is also released here via
// defer because the timeout only covers the connection phase.
//
// Any failure is reported with log.Fatalf, which terminates the process, so
// in practice the returned error is nil whenever this function returns.
func (dc *DuckConn) ConnectToDB(ctx context.Context) (err error) {
	// Keep the derived context instead of discarding it: without it the
	// timeout would never apply to anything.
	var connCtx context.Context
	connCtx, dc.Cancel = context.WithTimeout(ctx, 1*time.Minute)
	defer dc.Cancel()
	dc.Db, err = sql.Open("duckdb", dc.Path)
	if err != nil {
		log.Fatalf("Could not connect to DuckDB %v\n", err)
	}
	// sql.Open may defer the real work; ping so an unusable database is
	// reported here rather than on the first query.
	if err = dc.Db.PingContext(connCtx); err != nil {
		log.Fatalf("Could not connect to DuckDB %v\n", err)
	}
	return err
}
18 changes: 18 additions & 0 deletions sourcerer/get_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,21 @@ func (bqc *BqConn) GetColumns(ctx context.Context, t shared.SourceTable) ([]shar
}
return cs, nil
}

// GetColumns returns the name and data type of every column of table t in the
// schema dc.Schema, as listed by information_schema.columns.
//
// Query and scan failures terminate the process through log.Fatalf, as they
// did before. An error that interrupts row iteration is returned to the
// caller so that a truncated column list is never mistaken for success.
func (dc *DuckConn) GetColumns(ctx context.Context, t shared.SourceTable) ([]shared.Column, error) {
	var cs []shared.Column
	// Bind the schema and table name as parameters rather than formatting
	// them into the SQL text: names containing a quote would otherwise break
	// the statement (or allow injection).
	const q = "SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = ? AND table_name = ?"
	rows, err := dc.Db.QueryContext(ctx, q, dc.Schema, t.Name)
	if err != nil {
		log.Fatalf("Error fetching columns for table %s: %v\n", t.Name, err)
	}
	defer rows.Close()
	for rows.Next() {
		c := shared.Column{}
		if err := rows.Scan(&c.Name, &c.DataType); err != nil {
			log.Fatalf("Error scanning columns for table %s: %v\n", t.Name, err)
		}
		cs = append(cs, c)
	}
	// rows.Next returns false both at the end of the result set and on
	// failure; rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("reading columns for table %s: %w", t.Name, err)
	}
	return cs, nil
}
16 changes: 16 additions & 0 deletions sourcerer/get_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ type BqConn struct {
Cancel context.CancelFunc
}

// DuckConn holds the settings and the open handle used to read metadata from
// a DuckDB database.
type DuckConn struct {
// Path is handed to sql.Open as the DuckDB data source name.
Path string
// Database is copied from the connection details by GetConn.
// NOTE(review): no code visible in this change reads it — confirm intended use.
Database string
// Schema is the table_schema value used to filter information_schema queries.
Schema string
// Db is the database handle; it is assigned by ConnectToDB.
Db *sql.DB
// Cancel is the cancel function of the timeout context created by ConnectToDB.
Cancel context.CancelFunc
}

func GetConn(cd shared.ConnectionDetails) (DbConn, error) {
switch cd.ConnType {
case "snowflake":
Expand All @@ -49,6 +57,14 @@ func GetConn(cd shared.ConnectionDetails) (DbConn, error) {
Project: cd.Project,
Dataset: cd.Dataset,
}, nil
case "duckdb":
{
return &DuckConn{
Path: cd.Path,
Database: cd.Database,
Schema: cd.Schema,
}, nil
}
default:
return nil, errors.New("unsupported connection type")
}
Expand Down
24 changes: 24 additions & 0 deletions sourcerer/get_sources.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,27 @@ func (bqc *BqConn) GetSources(ctx context.Context) (shared.SourceTables, error)
bqc.PutColumnsOnTables(ctx, ts)
return ts, nil
}

// GetSources connects to the DuckDB database, lists every table found in the
// schema dc.Schema via information_schema.tables and then attaches column
// metadata to each table through PutColumnsOnTables.
//
// Connection, query and scan failures terminate the process through
// log.Fatalf, as they did before. An error that interrupts row iteration is
// returned to the caller so that a truncated table list is never mistaken for
// success.
func (dc *DuckConn) GetSources(ctx context.Context) (shared.SourceTables, error) {
	ts := shared.SourceTables{}
	err := dc.ConnectToDB(ctx)
	// ConnectToDB assigns dc.Cancel before doing anything that can fail, so
	// deferring it ahead of the error check is safe.
	defer dc.Cancel()
	if err != nil {
		log.Fatalf("Couldn't connect to database: %v\n", err)
	}
	// The schema name is user-entered; bind it as a parameter rather than
	// formatting it into the SQL text.
	const q = "SELECT table_name FROM information_schema.tables WHERE table_schema = ?"
	rows, err := dc.Db.QueryContext(ctx, q, dc.Schema)
	if err != nil {
		log.Fatalf("Error fetching tables: %v\n", err)
	}
	defer rows.Close()
	for rows.Next() {
		var table shared.SourceTable
		if err := rows.Scan(&table.Name); err != nil {
			log.Fatalf("Error scanning tables: %v\n", err)
		}
		ts.SourceTables = append(ts.SourceTables, table)
	}
	// rows.Next returns false both at the end of the result set and on
	// failure; rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return shared.SourceTables{}, fmt.Errorf("reading tables in schema %q: %w", dc.Schema, err)
	}
	dc.PutColumnsOnTables(ctx, ts)
	return ts, nil
}
61 changes: 24 additions & 37 deletions sourcerer/put_columns_on_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ import (
)

func (sfc *SfConn) PutColumnsOnTables(ctx context.Context, tables shared.SourceTables) {
mutex := sync.Mutex{}

var wg sync.WaitGroup
wg.Add(len(tables.SourceTables))

dataTypeGroupMap := map[string]string{
"(text|char)": "text",
"(float|int|num)": "numbers",
Expand All @@ -24,40 +19,10 @@ func (sfc *SfConn) PutColumnsOnTables(ctx context.Context, tables shared.SourceT
"date": "datetimes",
"timestamp": "timestamps",
}

for i := range tables.SourceTables {
go func(i int) {
defer wg.Done()

columns, err := sfc.GetColumns(ctx, tables.SourceTables[i])
if err != nil {
log.Fatalf("Error fetching columns for table %s: %v\n", tables.SourceTables[i].Name, err)
return
}

mutex.Lock()
tables.SourceTables[i].Columns = columns
tables.SourceTables[i].DataTypeGroups = make(map[string][]shared.Column)
// Create a map of data types groups to hold column slices by data type
// This lets us group columns by their data type e.g. in templates
for j := range tables.SourceTables[i].Columns {
for k, v := range dataTypeGroupMap {
r, _ := regexp.Compile(fmt.Sprintf(`(?i).*%s.*`, k))
if r.MatchString(tables.SourceTables[i].Columns[j].DataType) {
tables.SourceTables[i].DataTypeGroups[v] = append(tables.SourceTables[i].DataTypeGroups[v], tables.SourceTables[i].Columns[j])
}
}
}
mutex.Unlock()
}(i)
}
wg.Wait()
columnPutter(ctx, tables, sfc, dataTypeGroupMap)
}

func (bqc *BqConn) PutColumnsOnTables(ctx context.Context, tables shared.SourceTables) {
mutex := sync.Mutex{}
var wg sync.WaitGroup
wg.Add(len(tables.SourceTables))
dataTypeGroupMap := map[string]string{
"(string)": "text",
"(float|int)": "numbers",
Expand All @@ -66,14 +31,36 @@ func (bqc *BqConn) PutColumnsOnTables(ctx context.Context, tables shared.SourceT
"(date)": "datetimes",
"(timestamp)": "timestamps",
}
columnPutter(ctx, tables, bqc, dataTypeGroupMap)
}

// PutColumnsOnTables fetches the columns of every table in tables and groups
// them by data type, delegating the work to columnPutter with the DuckDB
// specific type patterns below.
//
// Each key is a regular-expression fragment tested against a column's data
// type; each value is the name of the group a matching column is added to.
func (dc *DuckConn) PutColumnsOnTables(ctx context.Context, tables shared.SourceTables) {
	dataTypeGroupMap := map[string]string{
		"(string|varchar)": "text",
		// DuckDB reports floating point and fixed point columns as DOUBLE
		// and DECIMAL(p,s); neither contains "float" or "int", so they are
		// listed explicitly to land in the numbers group.
		"(float|double|decimal|int)": "numbers",
		"(bool)":                     "booleans",
		"(json)":                     "json",
		"(date)":                     "datetimes",
		"(timestamp)":                "timestamps",
	}
	columnPutter(ctx, tables, dc, dataTypeGroupMap)
}

func columnPutter(ctx context.Context, tables shared.SourceTables, conn DbConn, dataTypeGroupMap map[string]string) {
mutex := sync.Mutex{}

var wg sync.WaitGroup
wg.Add(len(tables.SourceTables))
for i := range tables.SourceTables {
go func(i int) {
defer wg.Done()
columns, err := bqc.GetColumns(ctx, tables.SourceTables[i])

columns, err := conn.GetColumns(ctx, tables.SourceTables[i])
if err != nil {
log.Fatalf("Error fetching columns for table %s: %v\n", tables.SourceTables[i].Name, err)
return
}

mutex.Lock()
tables.SourceTables[i].Columns = columns
tables.SourceTables[i].DataTypeGroups = make(map[string][]shared.Column)
Expand Down

0 comments on commit 1501f11

Please sign in to comment.