diff --git a/README.md b/README.md index 7e3d2c6..6a69f28 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ cd dbs dolt clone ``` -Finally you can create the dbs directory as shown above and then create the database in code using a SQL `CREATE TABLE` statement +Finally, you can create the dbs directory as shown above and then create the database in code using a SQL `CREATE TABLE` statement ### Connecting to the Database @@ -61,3 +61,31 @@ clientfoundrows - If set to true, returns the number of matching rows instead of #### Example DSN `file:///path/to/dbs?commitname=Your%20Name&commitemail=your@email.com&database=databasename` + +### Multi-Statement Support + +If you pass the `multistatements=true` parameter in the DSN, you can execute multiple statements in one query. The returned +rows allow you to iterate over the returned result sets by using the `NextResultSet` method, just like you can with the +MySQL driver. + +```go +rows, err := db.Query("SELECT * from someTable; SELECT * from anotherTable;") +// If an error is returned, it means it came from the first statement +if err != nil { + panic(err) +} + +for rows.Next() { + // process the first result set +} + +if rows.NextResultSet() { + for rows.Next() { + // process the second result set + } +} else { + // If NextResultSet returns false when there were more statements, it means there was an error, + // which you can access through rows.Err() + panic(rows.Err()) +} +``` diff --git a/conn.go b/conn.go index 591ceca..7ae4b03 100644 --- a/conn.go +++ b/conn.go @@ -5,12 +5,11 @@ import ( "database/sql" "database/sql/driver" "fmt" - "io" "time" "github.com/dolthub/dolt/go/cmd/dolt/commands/engine" - gms "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/vitess/go/vt/sqlparser" ) var _ driver.Conn = (*DoltConn)(nil) @@ -22,71 +21,54 @@ type DoltConn struct { DataSource *DoltDataSource } -// Prepare returns a prepared statement, bound to this connection. +// Prepare packages up |query| as a *doltStmt so it can be executed. If multistatements mode +// has been enabled, then a *doltMultiStmt will be returned, capable of executing multiple statements. func (d *DoltConn) Prepare(query string) (driver.Stmt, error) { - multiStatements := d.DataSource.ParamIsTrue(MultiStatementsParam) + // Reuse the same ctx instance, but update the QueryTime to the current time. + // Statements are executed serially on a connection, so it's safe to reuse + // the same ctx instance and update the time. + d.gmsCtx.SetQueryTime(time.Now()) - if multiStatements { - scanner := gms.NewMysqlParser() - parsed, prequery, remainder, err := scanner.Parse(d.gmsCtx, query, true) - if err != nil { - return nil, translateError(err) - } + if d.DataSource.ParamIsTrue(MultiStatementsParam) { + return d.prepareMultiStatement(query) + } else { + return d.prepareSingleStatement(query) + } +} - for { - if len(remainder) == 0 { - query = prequery - break - } - - err = func() error { - var rowIter gms.RowIter - _, rowIter, err = d.se.GetUnderlyingEngine().QueryWithBindings(d.gmsCtx, prequery, parsed, nil) - if err != nil { - return translateError(err) - } - defer rowIter.Close(d.gmsCtx) - - for { - _, err := rowIter.Next(d.gmsCtx) - if err == io.EOF { - break - } else if err != nil { - return translateError(err) - } - } - - return nil - }() - if err != nil { - return nil, err - } - - parsed, prequery, remainder, err = scanner.Parse(d.gmsCtx, remainder, true) - if err != nil { - return nil, translateError(err) - } - } - if prequery != "" { - query = prequery +// prepareSingleStatement creates a doltStmt from |query|. +func (d *DoltConn) prepareSingleStatement(query string) (*doltStmt, error) { + return &doltStmt{ + query: query, + se: d.se, + gmsCtx: d.gmsCtx, + }, nil +} + +// prepareMultiStatement creates a doltStmt from each individual statement in |query|. +func (d *DoltConn) prepareMultiStatement(query string) (*doltMultiStmt, error) { + var doltMultiStmt doltMultiStmt + scanner := gms.NewMysqlParser() + + remainder := query + var err error + for remainder != "" { + _, query, remainder, err = scanner.Parse(d.gmsCtx, remainder, true) + if err == sqlparser.ErrEmpty { + // Skip over any empty statements + continue + } else if err != nil { + return nil, translateError(err) } - } - if len(query) > 0 { - _, err := d.se.GetUnderlyingEngine().PrepareQuery(d.gmsCtx, query) + doltStmt, err := d.prepareSingleStatement(query) if err != nil { return nil, translateError(err) } + doltMultiStmt.stmts = append(doltMultiStmt.stmts, doltStmt) } - // Reuse the same ctx instance, but update the QueryTime to the current time. Since statements are - // executed serially on a connection, it's safe to reuse the same ctx instance and update the time. - d.gmsCtx.SetQueryTime(time.Now()) - return &doltStmt{ - query: query, - se: d.se, - gmsCtx: d.gmsCtx, - }, nil + return &doltMultiStmt, nil } // Close releases the resources held by the DoltConn instance diff --git a/driver.go b/driver.go index 84813b9..74be90f 100644 --- a/driver.go +++ b/driver.go @@ -42,7 +42,6 @@ type doltDriver struct { // // The path needs to point to a directory whose subdirectories are dolt databases. If a "Create Database" command is // run a new subdirectory will be created in this path. -// The supported parameters are func (d *doltDriver) Open(dataSource string) (driver.Conn, error) { ctx := context.Background() var fs filesys.Filesys = filesys.LocalFS @@ -89,7 +88,7 @@ func (d *doltDriver) Open(dataSource string) (driver.Conn, error) { ServerUser: "root", Autocommit: true, } - + se, err := engine.NewSqlEngine(ctx, mrEnv, seCfg) if err != nil { return nil, err @@ -122,16 +121,16 @@ func (d *doltDriver) Open(dataSource string) (driver.Conn, error) { // with initialized environments for each of those subfolder data repositories. subfolders whose name starts with '.' are // skipped. func LoadMultiEnvFromDir( - ctx context.Context, - cfg config.ReadWriteConfig, - fs filesys.Filesys, - path, version string, + ctx context.Context, + cfg config.ReadWriteConfig, + fs filesys.Filesys, + path, version string, ) (*env.MultiRepoEnv, error) { multiDbDirFs, err := fs.WithWorkingDir(path) if err != nil { return nil, errhand.VerboseErrorFromError(err) } - + return env.MultiEnvForDirectory(ctx, cfg, multiDbDirFs, version, nil) } diff --git a/example/main.go b/example/main.go index 3e5fffe..65662b7 100644 --- a/example/main.go +++ b/example/main.go @@ -39,16 +39,16 @@ func main() { db, err := sql.Open("dolt", dataSource) errExit("failed to open database using the dolt driver: %w", err) - err = printQuery(ctx, db, "CREATE DATABASE IF NOT EXISTS testdb;USE testdb;") + err = printQuery(ctx, db, "CREATE DATABASE IF NOT EXISTS testdb; USE testdb;") errExit("", err) err = printQuery(ctx, db, "USE testdb;") errExit("", err) - printQuery(ctx, db, `CREATE TABLE IF NOT EXISTS t2( - pk int primary key auto_increment, - c1 varchar(32) -)`) + err = printQuery(ctx, db, `CREATE TABLE IF NOT EXISTS t2( + pk int primary key auto_increment, + c1 varchar(32) + )`) errExit("", err) printQuery(ctx, db, "SHOW TABLES;") @@ -63,21 +63,22 @@ func main() { fmt.Println(result.LastInsertId()) err = printQuery(ctx, db, `CREATE TABLE IF NOT EXISTS t1 ( - pk int PRIMARY KEY, - c1 varchar(512), - c2 float, - c3 bool, - c4 datetime -);`) + pk int PRIMARY KEY, + c1 varchar(512), + c2 float, + c3 bool, + c4 datetime + );`) + errExit("", err) err = printQuery(ctx, db, "SELECT * FROM t1;") errExit("", err) err = printQuery(ctx, db, `REPLACE INTO t1 VALUES -(1, 'this is a test', 0, 0, '1998-01-23 12:45:56'), -(2, 'it is only a test', 1.0, 1, '2010-12-31 01:15:00'), -(3, NULL, 3.335, 0, NULL), -(4, 'something something', 3.5, 1, '2015-04-03 14:00:45');`) + (1, 'this is a test', 0, 0, '1998-01-23 12:45:56'), + (2, 'it is only a test', 1.0, 1, '2010-12-31 01:15:00'), + (3, NULL, 3.335, 0, NULL), + (4, 'something something', 3.5, 1, '2015-04-03 14:00:45');`) errExit("", err) err = printQuery(ctx, db, "SELECT * FROM t1;") diff --git a/go.mod b/go.mod index c2e5532..fe34147 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,9 @@ go 1.22.2 toolchain go1.22.3 require ( - github.com/dolthub/dolt/go v0.40.5-0.20240604165632-02f450318cb3 - github.com/dolthub/go-mysql-server v0.18.2-0.20240604161217-d1dca79a32b8 - github.com/dolthub/vitess v0.0.0-20240603172811-467efd832e48 + github.com/dolthub/dolt/go v0.40.5-0.20240702155756-bcf4dd5f5cc1 + github.com/dolthub/go-mysql-server v0.18.2-0.20240702022058-d7eb602c04ee + github.com/dolthub/vitess v0.0.0-20240709194214-7926ea9d425d github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/stretchr/testify v1.8.4 gorm.io/driver/mysql v1.5.6 @@ -71,6 +71,7 @@ require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect + github.com/mohae/uvarint v0.0.0-20160208145430-c3f9e62bf2b0 // indirect github.com/oracle/oci-go-sdk/v65 v65.55.0 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -94,17 +95,17 @@ require ( go.opentelemetry.io/otel/trace v1.23.1 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.26.0 // indirect - golang.org/x/crypto v0.21.0 // indirect + golang.org/x/crypto v0.23.0 // indirect golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 // indirect - golang.org/x/mod v0.15.0 // indirect - golang.org/x/net v0.23.0 // indirect + golang.org/x/mod v0.17.0 // indirect + golang.org/x/net v0.25.0 // indirect golang.org/x/oauth2 v0.17.0 // indirect - golang.org/x/sync v0.6.0 // indirect - golang.org/x/sys v0.18.0 // indirect - golang.org/x/term v0.18.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/sys v0.20.0 // indirect + golang.org/x/term v0.20.0 // indirect + golang.org/x/text v0.16.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.18.0 // indirect + golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect google.golang.org/api v0.164.0 // indirect google.golang.org/appengine v1.6.8 // indirect diff --git a/go.sum b/go.sum index b4fcfd3..aed7f82 100644 --- a/go.sum +++ b/go.sum @@ -246,8 +246,8 @@ github.com/dimchansky/utfbom v1.1.0/go.mod h1:rO41eb7gLfo8SF1jd9F8HplJm1Fewwi4mQ github.com/dimchansky/utfbom v1.1.1/go.mod h1:SxdoEBH5qIqFocHMyGOXVAybYJdr71b1Q/j0mACtrfE= github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko= github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= -github.com/dolthub/dolt/go v0.40.5-0.20240604165632-02f450318cb3 h1:dprIlHxkwCRjzBVUop1FgwUDe3aiF118XVf6n866B/E= -github.com/dolthub/dolt/go v0.40.5-0.20240604165632-02f450318cb3/go.mod h1:bDWUUxoq/7AxDszTKFneulNc6Uh1PLqK897LjqKfoWY= +github.com/dolthub/dolt/go v0.40.5-0.20240702155756-bcf4dd5f5cc1 h1:zja4D6qChO7OZqh00buv9FTVu5pYzLEq1jptxpATcQE= +github.com/dolthub/dolt/go v0.40.5-0.20240702155756-bcf4dd5f5cc1/go.mod h1:QaKI/d6K38jAtq2gn11lQz+rkHECSwlbEzHyWSts+g0= github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20240212175631-02e9f99a3a9b h1:VehmKUF425NgXpRQVYMPzJx6rWZaJ2cbTwTTwXlrbiM= github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20240212175631-02e9f99a3a9b/go.mod h1:gHeHIDGU7em40EhFTliq62pExFcc1hxDTIZ9g5UqXYM= github.com/dolthub/flatbuffers v1.13.0-dh.1 h1:OWJdaPep22N52O/0xsUevxJ6Qfw1M2txCjZPOdjXybE= @@ -258,20 +258,22 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e h1:kPsT4a47cw1+y/N5SSCkma7FhAPw7KeGmD6c9PBZW9Y= github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168= -github.com/dolthub/go-mysql-server v0.18.2-0.20240604161217-d1dca79a32b8 h1:F9tnktZhnglXHYiG/tjnVMiLQSBZ8iJiflqgHg+ldKo= -github.com/dolthub/go-mysql-server v0.18.2-0.20240604161217-d1dca79a32b8/go.mod h1:GT7JcQavIf7bAO17/odujkgHM/N0t4b1HfAPBJ2jzXo= +github.com/dolthub/go-mysql-server v0.18.2-0.20240702022058-d7eb602c04ee h1:VYwVsWT3byEtq6W8ebAVO7cNCPUKeUNr590s/U6F3wo= +github.com/dolthub/go-mysql-server v0.18.2-0.20240702022058-d7eb602c04ee/go.mod h1:JahRYjx/Py6T/bWrnTu25CaGn94Df+McAuWGEG0shwU= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= -github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 h1:0HHu0GWJH0N6a6keStrHhUAK5/o9LVfkh44pvsV4514= -github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488/go.mod h1:ehexgi1mPxRTk0Mok/pADALuHbvATulTh6gzr7NzZto= +github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= +github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718/go.mod h1:ehexgi1mPxRTk0Mok/pADALuHbvATulTh6gzr7NzZto= github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTEtT5tOBsCuCrlYnLRKpbJVJkDbrTRhwQ= github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/maphash v0.1.0 h1:bsQ7JsF4FkkWyrP3oCnFJgrCUAFbFf3kOl4L/QxPDyQ= github.com/dolthub/maphash v0.1.0/go.mod h1:gkg4Ch4CdCDu5h6PMriVLawB7koZ+5ijb9puGMV50a4= github.com/dolthub/swiss v0.2.1 h1:gs2osYs5SJkAaH5/ggVJqXQxRXtWshF6uE0lgR/Y3Gw= github.com/dolthub/swiss v0.2.1/go.mod h1:8AhKZZ1HK7g18j7v7k6c5cYIGEZJcPn0ARsai8cUrh0= -github.com/dolthub/vitess v0.0.0-20240603172811-467efd832e48 h1:KfVnDVNytmTHeYZaQfUWZF/uE/fbLdLvXVdebQTPaMk= -github.com/dolthub/vitess v0.0.0-20240603172811-467efd832e48/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= +github.com/dolthub/vitess v0.0.0-20240626174323-4083c07f5e9c h1:Y3M0hPCUvT+5RTNbJLKywGc9aHIRCIlg+0NOhC91GYE= +github.com/dolthub/vitess v0.0.0-20240626174323-4083c07f5e9c/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= +github.com/dolthub/vitess v0.0.0-20240709194214-7926ea9d425d h1:qifIBMiYOCw/OLczNMBDg5ZMPEcEjrj5kSDeoyMXNBY= +github.com/dolthub/vitess v0.0.0-20240709194214-7926ea9d425d/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= @@ -610,6 +612,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= +github.com/mohae/uvarint v0.0.0-20160208145430-c3f9e62bf2b0 h1:fXRYk7YXVIBMGAHT+GmAcbiXrudXMPtqdLfbkVfUhkI= +github.com/mohae/uvarint v0.0.0-20160208145430-c3f9e62bf2b0/go.mod h1:+6ZKJfAk1B0oKLOwdzYuRVJn3upG1c7uOm5Ih7Rrkvc= github.com/montanaflynn/stats v0.6.6/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/ncw/swift v1.0.52/go.mod h1:23YIA4yWVnGwv2dQlN4bB7egfYX6YLn0Yo/S6zZO/ZM= @@ -795,8 +799,8 @@ golang.org/x/crypto v0.0.0-20220511200225-c6db032c6c88/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -822,8 +826,8 @@ golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+o golang.org/x/image v0.0.0-20200618115811-c13761719519/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20201208152932-35266b937fa6/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20210216034530-4410531fe030/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.10.0 h1:gXjUUtwtx5yOE0VKWq1CH4IJAClq4UGgUA3i+rpON9M= -golang.org/x/image v0.10.0/go.mod h1:jtrku+n79PfroUbvDdeUWMAI+heR786BofxrbiSF+J0= +golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= +golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -850,8 +854,8 @@ golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.5.0/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.15.0 h1:SernR4v+D55NyBH2QiEQrlBAnj1ECL6AGrA5+dPaMY8= -golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -905,8 +909,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= -golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -942,8 +946,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1024,16 +1028,16 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= -golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= +golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw= +golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1047,8 +1051,8 @@ golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1121,8 +1125,8 @@ golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.18.0 h1:k8NLag8AGHnn+PHbl7g43CtqZAwG60vZkLqgyZgIHgQ= -golang.org/x/tools v0.18.0/go.mod h1:GL7B4CwcLLeo59yx/9UWWuNOW1n3VZ4f5axWfML7Lcg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/rows.go b/rows.go index c7802f0..2ca6a12 100644 --- a/rows.go +++ b/rows.go @@ -10,7 +10,68 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" ) -var _ driver.Rows = (*doltRows)(nil) +// doltMultiRows implements driver.RowsNextResultSet by aggregating a set of individual +// doltRows instances. +type doltMultiRows struct { + rowSets []*doltRows + currentRowSet int +} + +var _ driver.RowsNextResultSet = (*doltMultiRows)(nil) + +func (d *doltMultiRows) Columns() []string { + if d.currentRowSet >= len(d.rowSets) { + return nil + } + + return d.rowSets[d.currentRowSet].Columns() +} + +// Close implements the driver.Rows interface. When Close is called on a doltMultiRows instance, +// it will close all individual doltRows instances that it contains. If any errors are encountered +// while closing the individual row sets, the first error will be returned, after attempting to close +// all row sets. +func (d *doltMultiRows) Close() error { + var retErr error + for _, rowSet := range d.rowSets { + if err := rowSet.Close(); err != nil { + retErr = err + } + } + return retErr +} + +func (d *doltMultiRows) Next(dest []driver.Value) error { + if d.currentRowSet >= len(d.rowSets) { + return io.EOF + } + + return d.rowSets[d.currentRowSet].Next(dest) +} + +func (d *doltMultiRows) HasNextResultSet() bool { + idx := d.currentRowSet + 1 + for ; idx < len(d.rowSets); idx++ { + if d.rowSets[idx].isQueryResultSet || d.rowSets[idx].err != nil { + return true + } + } + return false +} + +func (d *doltMultiRows) NextResultSet() error { + idx := d.currentRowSet + 1 + for ; idx < len(d.rowSets); idx++ { + if d.rowSets[idx].isQueryResultSet || d.rowSets[idx].err != nil { + // Update the current row set index when we find the next result set for a query. If we encountered an + // error running the statement earlier and saved an error in the row set, return that error now that the + // result set with the error has been requested. This matches the MySQL driver's behavior. + d.currentRowSet = idx + return d.rowSets[d.currentRowSet].err + } + } + return io.EOF +} type doltRows struct { sch gms.Schema @@ -18,8 +79,19 @@ type doltRows struct { gmsCtx *gms.Context columns []string + + // err holds any error encountered while trying to retrieve this result set + err error + + // isQueryResultSet indicates if this result set was generated by a statement that doesn't produce a result set. For + // example, an INSERT or DML statement doesn't return a result set, but we still keep track of a doltRows + // instance for their results in case an error was returned. This field is also used to skip over doltRows + // that are not result sets when calling NextResultSet() on a doltMultiRows instance. + isQueryResultSet bool } +var _ driver.Rows = (*doltRows)(nil) + // Columns returns the names of the columns. The number of columns of the result is inferred from the length of the // slice. If a particular column name isn't known, an empty string should be returned for that entry. func (rows *doltRows) Columns() []string { @@ -35,6 +107,10 @@ func (rows *doltRows) Columns() []string { // Close closes the rows iterator. func (rows *doltRows) Close() error { + if rows.rowIter == nil { + return nil + } + return translateError(rows.rowIter.Close(rows.gmsCtx)) } @@ -85,3 +161,40 @@ func (rows *doltRows) Next(dest []driver.Value) error { return nil } + +// peekableRowIter wrap another gms.RowIter and allows the caller to peek at results, without disturbing the order +// that results are returned from the Next() method. +type peekableRowIter struct { + iter gms.RowIter + peeks []gms.Row +} + +var _ gms.RowIter = (*peekableRowIter)(nil) + +// Peek returns the next row from this row iterator, without causing that row to be skipped from future calls +// to Next(). There is no limit on how many rows can be peeked. +func (p *peekableRowIter) Peek(ctx *gms.Context) (gms.Row, error) { + next, err := p.iter.Next(ctx) + if err != nil { + return nil, err + } + p.peeks = append(p.peeks, next) + + return next, nil +} + +// Next implements gms.RowIter +func (p *peekableRowIter) Next(ctx *gms.Context) (gms.Row, error) { + if len(p.peeks) > 0 { + peek := p.peeks[0] + p.peeks = p.peeks[1:] + return peek, nil + } + + return p.iter.Next(ctx) +} + +// Close implements gms.RowIter +func (p *peekableRowIter) Close(ctx *gms.Context) error { + return p.iter.Close(ctx) +} diff --git a/smoke_test.go b/smoke_test.go index 6d138d8..0819fc8 100644 --- a/smoke_test.go +++ b/smoke_test.go @@ -9,10 +9,44 @@ import ( "testing" "time" + _ "github.com/go-sql-driver/mysql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// runTestsAgainstMySQL can be set to true to run tests against a MySQL database using the MySQL driver. +// This is useful to test behavior compatibility between the Dolt driver and the MySQL driver. We +// want the Dolt driver to have the same semantics/behavior as the MySQL driver, so that customers +// familiar with using the MySQL driver, or code already using the MySQL driver, can easily switch +// to the Dolt driver. When this option is enabled, the MySQL database connection can be configured +// using mysqlDsn below. +var runTestsAgainstMySQL = false + +// mysqlDsn specifies the connection string for a MySQL database. Used only when the +// runTestsAgainstMySQL variable above is enabled. +var mysqlDsn = "root@tcp(localhost:3306)/?charset=utf8mb4&parseTime=True&loc=Local&multiStatements=true" + +// TestPreparedStatements tests that values can be plugged into "?" placeholders in queries. +func TestPreparedStatements(t *testing.T) { + conn, cleanupFunc := initializeTestDatabaseConnection(t, false) + defer cleanupFunc() + + ctx := context.Background() + rows, err := conn.QueryContext(ctx, "create table prepTest (id int, name varchar(256));") + require.NoError(t, err) + for rows.Next() { + } + require.NoError(t, rows.Err()) + require.NoError(t, rows.Close()) + + rows, err = conn.QueryContext(ctx, "insert into prepTest VALUES (?, ?);", 10, "foo") + require.NoError(t, err) + for rows.Next() { + } + require.NoError(t, rows.Err()) + require.NoError(t, rows.Close()) +} + func TestMultiStatements(t *testing.T) { conn, cleanupFunc := initializeTestDatabaseConnection(t, false) defer cleanupFunc() @@ -35,6 +69,8 @@ func TestMultiStatements(t *testing.T) { var id int var name string + // NOTE: Because the first two statements are not queries and don't have real result sets, the current result set + // is automatically positioned at the third statement. require.True(t, rows.Next()) require.NoError(t, rows.Scan(&id, &name)) require.Equal(t, 1, id) @@ -51,12 +87,265 @@ func TestMultiStatements(t *testing.T) { require.NoError(t, rows.Err()) require.NoError(t, rows.Close()) - _, err = conn.QueryContext(ctx, "select * from testtable; select * from doesnotexist; select * from testtable") - require.Error(t, err) + rows, err = conn.QueryContext(ctx, "select * from testtable; select * from doesnotexist; select * from testtable") + require.NoError(t, err) + + // The first result set contains all the rows from testtable + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&id, &name)) + require.Equal(t, 1, id) + require.Equal(t, "aaron", name) + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&id, &name)) + require.Equal(t, 2, id) + require.Equal(t, "brian", name) + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&id, &name)) + require.Equal(t, 3, id) + require.Equal(t, "tim", name) + require.False(t, rows.Next()) + require.NoError(t, rows.Err()) + + // The second result set has an error + require.False(t, rows.NextResultSet()) + require.NotNil(t, rows.Err()) + // MySQL returns a slightly different error message than Dolt + if !runTestsAgainstMySQL { + require.Equal(t, "Error 1146: table not found: doesnotexist", rows.Err().Error()) + } else { + require.Equal(t, "Error 1146 (42S02): Table 'testdb.doesnotexist' doesn't exist", rows.Err().Error()) + } + + // The third result set should have more rows... but we can't access them after the + // error in the second result set. This is the same behavior as the MySQL driver + require.False(t, rows.NextResultSet()) + require.NotNil(t, rows.Err()) require.NoError(t, conn.Close()) } +// TestMultiStatementsExecContext tests that using ExecContext to run a multi-statement query works as expected and +// matches the behavior of the MySQL driver. +func TestMultiStatementsExecContext(t *testing.T) { + conn, cleanupFunc := initializeTestDatabaseConnection(t, false) + defer cleanupFunc() + + ctx := context.Background() + _, err := conn.ExecContext(ctx, "CREATE TABLE example_table (id int, name varchar(256));") + require.NoError(t, err) + + // ExecContext returns the results from the LAST statement executed. This differs from the behavior for QueryContext. + result, err := conn.ExecContext(ctx, "INSERT into example_table VALUES (999, 'boo'); "+ + "INSERT into example_table VALUES (998, 'foo'); INSERT into example_table VALUES (997, 'goo'), (996, 'loo');") + require.NoError(t, err) + rowsAffected, err := result.RowsAffected() + require.NoError(t, err) + require.EqualValues(t, 2, rowsAffected) + + // Assert that all statements were correctly executed + requireResults(t, conn, "SELECT * FROM example_table ORDER BY id;", + [][]any{{996, "loo"}, {997, "goo"}, {998, "foo"}, {999, "boo"}}) + + // ExecContext returns an error if ANY of the statements can't be executed. This also differs from the behavior of QueryContext. + _, err = conn.ExecContext(ctx, "INSERT into example_table VALUES (100, 'woo'); "+ + "INSERT into example_table VALUES (1, 2, 'too many'); SET @allStatementsExecuted=1;") + require.NotNil(t, err) + if !runTestsAgainstMySQL { + require.Equal(t, "Error 1105: number of values does not match number of columns provided", err.Error()) + } else { + require.Equal(t, "Error 1136 (21S01): Column count doesn't match value count at row 1", err.Error()) + } + + // Assert that the first insert statement was executed before the error occurred + requireResults(t, conn, "SELECT * FROM example_table ORDER BY id;", + [][]any{{100, "woo"}, {996, "loo"}, {997, "goo"}, {998, "foo"}, {999, "boo"}}) + + // Once an error occurs, additional statements are NOT executed. This code tests that the last SET statement + // above was NOT executed. + requireResults(t, conn, "SELECT @allStatementsExecuted;", [][]any{{nil}}) +} + +// TestMultiStatementsQueryContext tests that using QueryContext to run a multi-statement query works as expected and +// matches the behavior of the MySQL driver. +func TestMultiStatementsQueryContext(t *testing.T) { + conn, cleanupFunc := initializeTestDatabaseConnection(t, false) + defer cleanupFunc() + + // QueryContext returns the results from the FIRST statement executed. This differs from the behavior for ExecContext. + ctx := context.Background() + rows, err := conn.QueryContext(ctx, "SELECT 1 FROM dual; SELECT 2 FROM dual; ") + require.NoError(t, err) + require.NoError(t, rows.Err()) + + var v any + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&v)) + require.EqualValues(t, 1, v) + require.False(t, rows.Next()) + + require.True(t, rows.NextResultSet()) + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&v)) + require.EqualValues(t, 2, v) + require.False(t, rows.Next()) + + require.False(t, rows.NextResultSet()) + require.NoError(t, rows.Close()) + + // QueryContext returns an error only if the FIRST statement can't be executed. + rows, err = conn.QueryContext(ctx, "SELECT * FROM no_table; SELECT 42 FROM dual;") + require.Nil(t, rows) + require.NotNil(t, err) + if !runTestsAgainstMySQL { + require.Equal(t, "Error 1146: table not found: no_table", err.Error()) + } else { + require.Equal(t, "Error 1146 (42S02): Table 'testdb.no_table' doesn't exist", err.Error()) + } + + // To access the error for statements after the first statement, you must use rows.Err() + rows, err = conn.QueryContext(ctx, "SELECT 42 FROM dual; SELECT * FROM no_table; SET @allStatementsExecuted=1;") + require.NoError(t, err) + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&v)) + require.EqualValues(t, 42, v) + require.False(t, rows.Next()) + require.False(t, rows.NextResultSet()) + require.NotNil(t, rows.Err()) + if !runTestsAgainstMySQL { + require.Equal(t, "Error 1146: table not found: no_table", rows.Err().Error()) + } else { + require.Equal(t, "Error 1146 (42S02): Table 'testdb.no_table' doesn't exist", rows.Err().Error()) + } + require.NoError(t, rows.Close()) + + // Once an error occurs, additional statements are NOT executed. This code tests that the last SET statement + // above was NOT executed. + rows, err = conn.QueryContext(ctx, "SELECT @allStatementsExecuted;") + require.NoError(t, err) + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&v)) + require.Nil(t, v) + require.NoError(t, rows.Close()) + + // Non-query statements don't return a real result set, so they are skipped over automatically + rows, err = conn.QueryContext(ctx, "SET @notUsed=1; SELECT 42 FROM dual; ") + require.NoError(t, err) + require.NoError(t, rows.Err()) + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&v)) + require.EqualValues(t, 42, v) + require.NoError(t, rows.Close()) + + // Queries that generate an empty result set are NOT skipped over automatically + rows, err = conn.QueryContext(ctx, "CREATE TABLE t (pk int primary key); SELECT * FROM t; SELECT 42 FROM dual;") + require.NoError(t, err) + require.NoError(t, rows.Err()) + require.False(t, rows.Next()) + require.True(t, rows.NextResultSet()) + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&v)) + require.EqualValues(t, 42, v) + require.NoError(t, rows.Close()) + + // If an error occurs between two valid queries, NextResulSet() returns false and exposes the + // error from rows.Err(). + rows, err = conn.QueryContext(ctx, "SELECT * FROM t; SELECT * from t2; SELECT 42 FROM dual;") + require.NoError(t, err) + require.NoError(t, rows.Err()) + require.False(t, rows.Next()) + require.False(t, rows.NextResultSet()) + require.NotNil(t, rows.Err()) + if !runTestsAgainstMySQL { + require.Equal(t, "Error 1146: table not found: t2", rows.Err().Error()) + } else { + require.Equal(t, "Error 1146 (42S02): Table 'testdb.t2' doesn't exist", rows.Err().Error()) + } + require.NoError(t, rows.Close()) + + // If an error occurs before the first real query results set, the error is returned, with no rows + rows, err = conn.QueryContext(ctx, "set @foo='bar'; SELECT * from t2; SELECT 42 FROM dual;") + require.NotNil(t, err) + require.Nil(t, rows) + if !runTestsAgainstMySQL { + require.Equal(t, "Error 1146: table not found: t2", err.Error()) + } else { + require.Equal(t, "Error 1146 (42S02): Table 'testdb.t2' doesn't exist", err.Error()) + } +} + +// TestMultiStatementsWithNoSpaces tests that multistatements are parsed correctly, even when +// there is no space between the statement delimiter and the next statement. +func TestMultiStatementsWithNoSpaces(t *testing.T) { + conn, cleanupFunc := initializeTestDatabaseConnection(t, false) + defer cleanupFunc() + + var v int + ctx := context.Background() + rows, err := conn.QueryContext(ctx, "select 42 from dual;select 43 from dual;") + + // Check the first result set + require.NoError(t, err) + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&v)) + require.Equal(t, 42, v) + require.NoError(t, rows.Err()) + require.False(t, rows.Next()) + + // Check the second result set + require.True(t, rows.NextResultSet()) + require.NoError(t, err) + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&v)) + require.Equal(t, 43, v) + require.NoError(t, rows.Err()) + require.False(t, rows.Next()) + require.NoError(t, rows.Close()) +} + +// TestMultiStatementsWithEmptyStatements tests that any empty statements in a multistatement query are skipped over. +// This includes statements that are entirely empty, as well as statements that contain only comments. +func TestMultiStatementsWithEmptyStatements(t *testing.T) { + conn, cleanupFunc := initializeTestDatabaseConnection(t, false) + defer cleanupFunc() + + var v int + ctx := context.Background() + + // Test that empty statements don't return errors and don't return result sets + rows, err := conn.QueryContext(ctx, "select 42 from dual; # This is an empty statement") + require.NoError(t, err) + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&v)) + require.Equal(t, 42, v) + require.NoError(t, rows.Err()) + require.False(t, rows.Next()) + require.False(t, rows.NextResultSet()) + require.NoError(t, rows.Close()) + + // Test another form of empty statement + rows, err = conn.QueryContext(ctx, "select 42 from dual; ; ; ; select 24 from dual; ;") + require.NoError(t, err) + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&v)) + require.Equal(t, 42, v) + require.NoError(t, rows.Err()) + require.False(t, rows.Next()) + + // NOTE: The MySQL driver does not allow moving past empty statements to the next result set + if !runTestsAgainstMySQL { + require.True(t, rows.NextResultSet()) + require.NoError(t, err) + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(&v)) + require.Equal(t, 24, v) + require.NoError(t, rows.Err()) + require.False(t, rows.Next()) + } + + require.False(t, rows.NextResultSet()) + require.NoError(t, rows.Close()) +} + func TestMultiStatementsStoredProc(t *testing.T) { conn, cleanupFunc := initializeTestDatabaseConnection(t, false) defer cleanupFunc() @@ -64,6 +353,19 @@ func TestMultiStatementsStoredProc(t *testing.T) { ctx := context.Background() rows, err := conn.QueryContext(ctx, "create procedure p() begin select 1; end; call p(); call p(); call p();") require.NoError(t, err) + + // NOTE: Because the first statement is not a query and doesn't have a real result set, the current result set + // is automatically positioned at the second statement. + for rows.Next() { + var i int + err = rows.Scan(&i) + require.NoError(t, err) + require.Equal(t, 1, i) + } + require.NoError(t, rows.Err()) + + // Advance to the third result set and check its rows + require.True(t, rows.NextResultSet()) for rows.Next() { var i int err = rows.Scan(&i) @@ -71,6 +373,17 @@ func TestMultiStatementsStoredProc(t *testing.T) { require.Equal(t, 1, i) } require.NoError(t, rows.Err()) + + // Advance to the fourth result set and check its rows + require.True(t, rows.NextResultSet()) + for rows.Next() { + var i int + err = rows.Scan(&i) + require.NoError(t, err) + require.Equal(t, 1, i) + } + require.NoError(t, rows.Err()) + require.NoError(t, rows.Close()) } @@ -86,6 +399,9 @@ func TestMultiStatementsTrigger(t *testing.T) { rows, err := conn.QueryContext(ctx, "create trigger trig before insert on t for each row begin set new.j = new.j * 100; end; insert into t values (1, 2); select * from t;") require.NoError(t, err) + + // NOTE: Because the first statement is not a query and doesn't have a real result set, the current result set + // is automatically positioned at the second statement. for rows.Next() { var i, j int err = rows.Scan(&i, &j) @@ -205,11 +521,11 @@ insert into testtable values ('b', 'a,c', '{"key": 42}', 'data', 'text', Point(5 ptrs[i] = &vals[i] } require.NoError(t, row.Scan(ptrs...)) - require.Equal(t, "b", vals[0]) - require.Equal(t, "a,c", vals[1]) - require.Equal(t, `{"key": 42}`, vals[2]) - require.Equal(t, []byte(`data`), vals[3]) - require.Equal(t, "text", vals[4]) + require.EqualValues(t, "b", vals[0]) + require.EqualValues(t, "a,c", vals[1]) + require.EqualValues(t, `{"key": 42}`, vals[2]) + require.EqualValues(t, []byte(`data`), vals[3]) + require.EqualValues(t, "text", vals[4]) require.IsType(t, []byte(nil), vals[5]) require.IsType(t, time.Time{}, vals[6]) } @@ -241,10 +557,30 @@ func initializeTestDatabaseConnection(t *testing.T, clientFoundRows bool) (conn require.NoError(t, err) require.NoError(t, db.PingContext(ctx)) + if runTestsAgainstMySQL { + dsn := mysqlDsn + if clientFoundRows { + dsn += "&clientFoundRows=true" + } + db, err = sql.Open("mysql", dsn) + require.NoError(t, err) + require.NoError(t, db.PingContext(ctx)) + } + conn, err = db.Conn(ctx) require.NoError(t, err) - res, err := conn.ExecContext(ctx, "create database testdb") + res, err := conn.ExecContext(ctx, "drop database if exists testdb") + require.NoError(t, err) + _, err = res.RowsAffected() + require.NoError(t, err) + + res, err = conn.ExecContext(ctx, "create database testdb") + require.NoError(t, err) + _, err = res.RowsAffected() + require.NoError(t, err) + + res, err = conn.ExecContext(ctx, "use testdb") require.NoError(t, err) _, err = res.RowsAffected() require.NoError(t, err) @@ -252,6 +588,30 @@ func initializeTestDatabaseConnection(t *testing.T, clientFoundRows bool) (conn return conn, cleanUpFunc } +// requireResults uses |conn| to run the specified |query| and asserts that the results +// match |expected|. If any differences are encountered, the current test fails. +func requireResults(t *testing.T, conn *sql.Conn, query string, expected [][]any) { + ctx := context.Background() + vals := make([]any, len(expected[0])) + + rows, err := conn.QueryContext(ctx, query) + require.NoError(t, err) + + for _, expectedRow := range expected { + for i := range vals { + vals[i] = &vals[i] + } + require.True(t, rows.Next()) + require.NoError(t, rows.Scan(vals...)) + for i, expectedVal := range expectedRow { + require.EqualValues(t, expectedVal, vals[i]) + } + } + + require.False(t, rows.Next()) + require.NoError(t, rows.Close()) +} + func encodeDir(dir string) string { // encodeDir translate a given path to a URL compatible path, mostly for windows compatibility if os.PathSeparator == '\\' { diff --git a/statement.go b/statement.go index f4e480a..60ee1ab 100644 --- a/statement.go +++ b/statement.go @@ -2,24 +2,93 @@ package embedded import ( "database/sql/driver" - - "github.com/dolthub/vitess/go/sqltypes" + "strconv" "github.com/dolthub/dolt/go/cmd/dolt/commands/engine" gms "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/vitess/go/sqltypes" querypb "github.com/dolthub/vitess/go/vt/proto/query" - - "strconv" ) -var _ driver.Stmt = (*doltStmt)(nil) +// doltMultiStmt represents a collection of statements to be executed against a +// Dolt database. +type doltMultiStmt struct { + stmts []*doltStmt +} + +var _ driver.Stmt = (*doltMultiStmt)(nil) + +func (d doltMultiStmt) Close() error { + var retErr error + for _, stmt := range d.stmts { + if err := stmt.Close(); err != nil { + retErr = err + } + } + + return retErr +} + +func (d doltMultiStmt) NumInput() int { + return -1 +} + +func (d doltMultiStmt) Exec(args []driver.Value) (result driver.Result, err error) { + for _, stmt := range d.stmts { + result, err = stmt.Exec(args) + if err != nil { + // If any error occurs, return the error and don't execute any more statements + return nil, err + } + } + + // Otherwise, return the last result, to match the MySQL driver's behavior + return result, nil +} + +func (d doltMultiStmt) Query(args []driver.Value) (driver.Rows, error) { + var multiResultSet doltMultiRows + for _, stmt := range d.stmts { + rows, err := stmt.Query(args) + if err != nil { + // If an error occurs, we don't execute any more statements in the multistatement query. Instead, we + // capture the error in a doltRows instance, so that rows.NextResultSet() will return the error when + // the caller requests that result set. This is to match the MySQL driver's behavior. + multiResultSet.rowSets = append(multiResultSet.rowSets, &doltRows{err: err}) + break + } else { + multiResultSet.rowSets = append(multiResultSet.rowSets, rows.(*doltRows)) + } + } + + // Position the current result set index at the first statement that is a query, with a real result set. In + // other words, skip over any statements that don't actually return results sets (e.g. INSERT or DDL statements). + for ; multiResultSet.currentRowSet < len(multiResultSet.rowSets); multiResultSet.currentRowSet++ { + if multiResultSet.rowSets[multiResultSet.currentRowSet].isQueryResultSet || + multiResultSet.rowSets[multiResultSet.currentRowSet].err != nil { + break + } + } + + // If an error occurred before any query result set, go ahead and return the error, without any result set. + if multiResultSet.currentRowSet < len(multiResultSet.rowSets) && + multiResultSet.rowSets[multiResultSet.currentRowSet].err != nil { + return nil, multiResultSet.rowSets[multiResultSet.currentRowSet].err + } else { + return &multiResultSet, nil + } +} +// doltStmt represents a single statement to be executed against a Dolt database. type doltStmt struct { se *engine.SqlEngine gmsCtx *gms.Context query string } +var _ driver.Stmt = (*doltStmt)(nil) + // Close closes the statement. func (stmt *doltStmt) Close() error { return nil @@ -83,14 +152,42 @@ func (stmt *doltStmt) Query(args []driver.Value) (driver.Rows, error) { } else { sch, rowIter, err = stmt.se.Query(stmt.gmsCtx, stmt.query) } - if err != nil { return nil, translateError(err) } + // Wrap the result iterator in a peekableRowIter and call Peek() to read the first row from the result iterator. + // This is necessary for insert operations, since the insert happens inside the result iterator logic. Without + // calling this now, insert statements and some DML statements (e.g. CREATE PROCEDURE) would not be executed yet, + // and future statements in a multi-statement query that depend on those results would fail. + // If an error does occur, we want that error to be returned in the Next() codepath, not here. + peekIter := peekableRowIter{iter: rowIter} + row, _ := peekIter.Peek(stmt.gmsCtx) + return &doltRows{ - sch: sch, - rowIter: rowIter, - gmsCtx: stmt.gmsCtx, + sch: sch, + rowIter: &peekIter, + gmsCtx: stmt.gmsCtx, + isQueryResultSet: isQueryResultSet(row), }, nil } + +// isQueryResultSet returns true if the specified |row| is a valid result set for a query. If row only contains +// one column and is an OkResult, or if row has zero columns, then the statement that generated this row was not +// a query. +func isQueryResultSet(row gms.Row) bool { + // If row is nil, return true since this could still be a valid, empty result set. + if row == nil { + return true + } + + if len(row) == 1 { + if _, ok := row[0].(types.OkResult); ok { + return false + } + } else if len(row) == 0 { + return false + } + + return true +}