Skip to content

Commit

Permalink
feat(custom-db): support custom db from url
Browse files Browse the repository at this point in the history
  • Loading branch information
code-R committed Jan 31, 2024
1 parent 3c53cc6 commit 233515d
Showing 1 changed file with 34 additions and 14 deletions.
48 changes: 34 additions & 14 deletions database/neo4j/neo4j.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"bytes"
"fmt"
"io"
"log"
neturl "net/url"
"strconv"
"strings"
"sync/atomic"

"github.com/golang-migrate/migrate/v4/database"
Expand Down Expand Up @@ -34,6 +36,7 @@ type Config struct {
MigrationsLabel string
MultiStatement bool
MultiStatementMaxSize int
DatabaseName string
}

type Neo4j struct {
Expand Down Expand Up @@ -91,7 +94,14 @@ func (n *Neo4j) Open(url string) (database.Driver, error) {
return nil, err
}

dbName := strings.Trim(uri.Path, "/")
if dbName == "" {
log.Printf("Using default neo4j database")
} else {
log.Printf("Using db name %s", dbName)
}
return WithInstance(driver, &Config{
DatabaseName: dbName,
MigrationsLabel: DefaultMigrationsLabel,
MultiStatement: multi,
MultiStatementMaxSize: multiStatementMaxSize,
Expand All @@ -118,11 +128,23 @@ func (n *Neo4j) Unlock() error {
return nil
}

func (n *Neo4j) Run(migration io.Reader) (err error) {
session, err := n.driver.Session(neo4j.AccessModeWrite)
if err != nil {
return err
func getReadSessionConfig(dbName string) neo4j.SessionConfig {
return neo4j.SessionConfig{
AccessMode: neo4j.AccessModeRead,
DatabaseName: dbName,
}
}

func getWriteSessionConfig(dbName string) neo4j.SessionConfig {
return neo4j.SessionConfig{
AccessMode: neo4j.AccessModeWrite,
DatabaseName: dbName,
}
}

func (n *Neo4j) Run(migration io.Reader) (err error) {
session := n.driver.NewSession(
getWriteSessionConfig(n.config.DatabaseName))
defer func() {
if cerr := session.Close(); cerr != nil {
err = multierror.Append(err, cerr)
Expand Down Expand Up @@ -166,10 +188,8 @@ func (n *Neo4j) Run(migration io.Reader) (err error) {
}

func (n *Neo4j) SetVersion(version int, dirty bool) (err error) {
session, err := n.driver.Session(neo4j.AccessModeWrite)
if err != nil {
return err
}
session := n.driver.NewSession(
getWriteSessionConfig(n.config.DatabaseName))
defer func() {
if cerr := session.Close(); cerr != nil {
err = multierror.Append(err, cerr)
Expand All @@ -191,10 +211,8 @@ type MigrationRecord struct {
}

func (n *Neo4j) Version() (version int, dirty bool, err error) {
session, err := n.driver.Session(neo4j.AccessModeRead)
if err != nil {
return database.NilVersion, false, err
}
session := n.driver.NewSession(
getReadSessionConfig(n.config.DatabaseName))
defer func() {
if cerr := session.Close(); cerr != nil {
err = multierror.Append(err, cerr)
Expand Down Expand Up @@ -239,7 +257,8 @@ ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`,
}

func (n *Neo4j) Drop() (err error) {
session, err := n.driver.Session(neo4j.AccessModeWrite)
session := n.driver.NewSession(
getWriteSessionConfig(n.config.DatabaseName))
if err != nil {
return err
}
Expand All @@ -256,7 +275,8 @@ func (n *Neo4j) Drop() (err error) {
}

func (n *Neo4j) ensureVersionConstraint() (err error) {
session, err := n.driver.Session(neo4j.AccessModeWrite)
session := n.driver.NewSession(
getWriteSessionConfig(n.config.DatabaseName))
if err != nil {
return err
}
Expand Down

0 comments on commit 233515d

Please sign in to comment.