Skip to content

Commit

Permalink
update CreateUser interface method
Browse files Browse the repository at this point in the history
  • Loading branch information
colindickson committed Nov 24, 2023
1 parent 876713e commit fad7d61
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 13 deletions.
2 changes: 1 addition & 1 deletion db/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type dialect interface {
Flush(tx Tx, ctx context.Context, l *Loader, outputModuleHash string, lastFinalBlock uint64) (int, error)
Revert(tx Tx, ctx context.Context, l *Loader, lastValidFinalBlock uint64) error
OnlyInserts() bool
CreateUser(ctx context.Context, username string, password string, database string, readOnly bool) string
CreateUser(tx Tx, ctx context.Context, l *Loader, username string, password string, database string, readOnly bool) error
}

var driverDialect = map[string]dialect{
Expand Down
21 changes: 14 additions & 7 deletions db/dialect_clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,20 +131,27 @@ func (d clickhouseDialect) OnlyInserts() bool {
return true
}

func (d clickhouseDialect) CreateUser(ctx context.Context, username, password, _database string, readOnly bool) string {
func (d clickhouseDialect) CreateUser(tx Tx, ctx context.Context, l *Loader, username string, password string, _database string, readOnly bool) error {
user, pass := EscapeIdentifier(username), EscapeIdentifier(password)
var q string
if readOnly {
// SQL statements for creating a read-only user in ClickHouse
return fmt.Sprintf(`
q = fmt.Sprintf(`
CREATE USER %s IDENTIFIED BY '%s';
GRANT SELECT ON *.* TO %s;
`, username, password, username)
`, user, pass, user)
} else {
// SQL statement for creating a read-write user in ClickHouse
return fmt.Sprintf(`
q = fmt.Sprintf(`
CREATE USER %s IDENTIFIED BY '%s';
GRANT ALL ON *.* TO %s;
`, username, password, username)
`, user, pass, user)
}

_, err := tx.ExecContext(ctx, q)
if err != nil {
return fmt.Errorf("executing query %q: %w", q, err)
}

return nil
}

func convertOpToClickhouseValues(o *Operation) ([]any, error) {
Expand Down
16 changes: 11 additions & 5 deletions db/dialect_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,21 +235,27 @@ func (d postgresDialect) OnlyInserts() bool {
return false
}

func (d postgresDialect) CreateUser(ctx context.Context, username, password string, database string, readOnly bool) string {
func (d postgresDialect) CreateUser(tx Tx, ctx context.Context, l *Loader, username string, password string, database string, readOnly bool) error {
user, pass, db := EscapeIdentifier(username), EscapeIdentifier(password), EscapeIdentifier(database)
var q string
if readOnly {
// SQL statements for creating a read-only user
return fmt.Sprintf(`
q = fmt.Sprintf(`
CREATE USER %s WITH PASSWORD '%s';
GRANT CONNECT ON DATABASE %s TO %s;
GRANT USAGE ON SCHEMA public TO %s;
GRANT SELECT ON ALL TABLES IN SCHEMA public TO %s;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO %s;
`, user, pass, db, user, user, user, user)
} else {
// SQL statement for creating a read-write user
return fmt.Sprintf("CREATE USER %s WITH PASSWORD '%s'; GRANT ALL PRIVILEGES ON DATABASE %s TO %s;", user, pass, db, user)
q = fmt.Sprintf("CREATE USER %s WITH PASSWORD '%s'; GRANT ALL PRIVILEGES ON DATABASE %s TO %s;", user, pass, db, user)
}

_, err := tx.ExecContext(ctx, q)
if err != nil {
return fmt.Errorf("executing query %q: %w", q, err)
}

return nil
}

func (d postgresDialect) historyTable(schema string) string {
Expand Down
34 changes: 34 additions & 0 deletions db/user.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package db

import (
"context"
"fmt"

"go.uber.org/zap"
)

func (l *Loader) CreateUser(ctx context.Context, username string, password string, database string, readOnly bool) (err error) {
tx, err := l.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to being db transaction: %w", err)
}
defer func() {
if err != nil {
if err := tx.Rollback(); err != nil {
l.logger.Warn("failed to rollback transaction", zap.Error(err))
}
}
}()

err = l.getDialect().CreateUser(tx, ctx, l, username, password, database, readOnly)
if err != nil {
return fmt.Errorf("create user: %w", err)
}

if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit db transaction: %w", err)
}
l.reset()

return nil
}

0 comments on commit fad7d61

Please sign in to comment.