Skip to content

Commit

Permalink
Merge pull request #17 from kanmu/support_connector
Browse files Browse the repository at this point in the history
Support driver.Connector
  • Loading branch information
winebarrel authored Nov 27, 2024
2 parents 84bca5e + 9004e46 commit e3780a3
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
24 changes: 24 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,30 @@ type txDriver struct {
dsn string
}

type txConnector struct {
driver *txDriver
dsn string
}

func (c *txConnector) Driver() driver.Driver {
return c.driver
}

func (c *txConnector) Connect(ctx context.Context) (driver.Conn, error) {
return c.driver.Open(c.dsn)
}

func NewConnector(dsn, srcDrv, srcDsn string) driver.Connector {
return &txConnector{
driver: &txDriver{
dsn: srcDsn,
drv: srcDrv,
conns: make(map[string]*conn),
},
dsn: dsn,
}
}

func (d *txDriver) Open(dsn string) (driver.Conn, error) {
d.Lock()
defer d.Unlock()
Expand Down
30 changes: 30 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,36 @@ func TestShouldRunWithinTransaction(t *testing.T) {
}
}

func TestShouldRunWithinTransactionForOpenDB(t *testing.T) {
t.Parallel()
var count int
db1 := sql.OpenDB(pgtxdb.NewConnector("one", "pgx", "postgres://pgtxdbtest@localhost:5432/pgtxdbtest?sslmode=disable"))
defer db1.Close()

_, err := db1.Exec(`INSERT INTO app_user(username, email) VALUES('txdb', '[email protected]')`)
if err != nil {
t.Fatalf("failed to insert an app_user: %s", err)
}
err = db1.QueryRow("SELECT COUNT(id) FROM app_user").Scan(&count)
if err != nil {
t.Fatalf("failed to count users: %s", err)
}
if count != 1 {
t.Fatalf("expected 1 user to be in database, but got %d", count)
}

db2 := sql.OpenDB(pgtxdb.NewConnector("two", "pgx", "postgres://pgtxdbtest@localhost:5432/pgtxdbtest?sslmode=disable"))
defer db2.Close()

err = db2.QueryRow("SELECT COUNT(id) FROM app_user").Scan(&count)
if err != nil {
t.Fatalf("failed to count app_user: %s", err)
}
if count != 0 {
t.Fatalf("expected 0 user to be in database, but got %d", count)
}
}

func TestShouldNotHoldConnectionForRows(t *testing.T) {
t.Parallel()
db, err := sql.Open("pgtxdb", "three")
Expand Down

0 comments on commit e3780a3

Please sign in to comment.