From fad7d617f0d81ebbb3da76db2a2bc619c4df8b9a Mon Sep 17 00:00:00 2001 From: colindickson Date: Fri, 24 Nov 2023 14:53:38 -0500 Subject: [PATCH] update CreateUser interface method --- db/dialect.go | 2 +- db/dialect_clickhouse.go | 21 ++++++++++++++------- db/dialect_postgres.go | 16 +++++++++++----- db/user.go | 34 ++++++++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 13 deletions(-) create mode 100644 db/user.go diff --git a/db/dialect.go b/db/dialect.go index a6a23f8..4e3ea5d 100644 --- a/db/dialect.go +++ b/db/dialect.go @@ -26,7 +26,7 @@ type dialect interface { Flush(tx Tx, ctx context.Context, l *Loader, outputModuleHash string, lastFinalBlock uint64) (int, error) Revert(tx Tx, ctx context.Context, l *Loader, lastValidFinalBlock uint64) error OnlyInserts() bool - CreateUser(ctx context.Context, username string, password string, database string, readOnly bool) string + CreateUser(tx Tx, ctx context.Context, l *Loader, username string, password string, database string, readOnly bool) error } var driverDialect = map[string]dialect{ diff --git a/db/dialect_clickhouse.go b/db/dialect_clickhouse.go index ec125f3..c50d2db 100644 --- a/db/dialect_clickhouse.go +++ b/db/dialect_clickhouse.go @@ -131,20 +131,27 @@ func (d clickhouseDialect) OnlyInserts() bool { return true } -func (d clickhouseDialect) CreateUser(ctx context.Context, username, password, _database string, readOnly bool) string { +func (d clickhouseDialect) CreateUser(tx Tx, ctx context.Context, l *Loader, username string, password string, _database string, readOnly bool) error { + user, pass := EscapeIdentifier(username), EscapeIdentifier(password) + var q string if readOnly { - // SQL statements for creating a read-only user in ClickHouse - return fmt.Sprintf(` + q = fmt.Sprintf(` CREATE USER %s IDENTIFIED BY '%s'; GRANT SELECT ON *.* TO %s; - `, username, password, username) + `, user, pass, user) } else { - // SQL statement for creating a read-write user in ClickHouse - return fmt.Sprintf(` + q = fmt.Sprintf(` CREATE USER %s IDENTIFIED BY '%s'; GRANT ALL ON *.* TO %s; - `, username, password, username) + `, user, pass, user) } + + _, err := tx.ExecContext(ctx, q) + if err != nil { + return fmt.Errorf("executing query %q: %w", q, err) + } + + return nil } func convertOpToClickhouseValues(o *Operation) ([]any, error) { diff --git a/db/dialect_postgres.go b/db/dialect_postgres.go index 9af3062..f0354fd 100644 --- a/db/dialect_postgres.go +++ b/db/dialect_postgres.go @@ -235,11 +235,11 @@ func (d postgresDialect) OnlyInserts() bool { return false } -func (d postgresDialect) CreateUser(ctx context.Context, username, password string, database string, readOnly bool) string { +func (d postgresDialect) CreateUser(tx Tx, ctx context.Context, l *Loader, username string, password string, database string, readOnly bool) error { user, pass, db := EscapeIdentifier(username), EscapeIdentifier(password), EscapeIdentifier(database) + var q string if readOnly { - // SQL statements for creating a read-only user - return fmt.Sprintf(` + q = fmt.Sprintf(` CREATE USER %s WITH PASSWORD '%s'; GRANT CONNECT ON DATABASE %s TO %s; GRANT USAGE ON SCHEMA public TO %s; @@ -247,9 +247,15 @@ func (d postgresDialect) CreateUser(ctx context.Context, username, password stri ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO %s; `, user, pass, db, user, user, user, user) } else { - // SQL statement for creating a read-write user - return fmt.Sprintf("CREATE USER %s WITH PASSWORD '%s'; GRANT ALL PRIVILEGES ON DATABASE %s TO %s;", user, pass, db, user) + q = fmt.Sprintf("CREATE USER %s WITH PASSWORD '%s'; GRANT ALL PRIVILEGES ON DATABASE %s TO %s;", user, pass, db, user) } + + _, err := tx.ExecContext(ctx, q) + if err != nil { + return fmt.Errorf("executing query %q: %w", q, err) + } + + return nil } func (d postgresDialect) historyTable(schema string) string { diff --git a/db/user.go b/db/user.go new file mode 100644 index 0000000..56e5d3e --- /dev/null +++ b/db/user.go @@ -0,0 +1,34 @@ +package db + +import ( + "context" + "fmt" + + "go.uber.org/zap" +) + +func (l *Loader) CreateUser(ctx context.Context, username string, password string, database string, readOnly bool) (err error) { + tx, err := l.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to being db transaction: %w", err) + } + defer func() { + if err != nil { + if err := tx.Rollback(); err != nil { + l.logger.Warn("failed to rollback transaction", zap.Error(err)) + } + } + }() + + err = l.getDialect().CreateUser(tx, ctx, l, username, password, database, readOnly) + if err != nil { + return fmt.Errorf("create user: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit db transaction: %w", err) + } + l.reset() + + return nil +}