diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c0054952..6354d30aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Remove global NewBatch function (CASSGO-15) +- Change Batch API to be consistent with Query() (CASSGO-7) + ### Fixed - Retry policy now takes into account query idempotency (CASSGO-27) diff --git a/batch_test.go b/batch_test.go index 25f8c8364..44b52663f 100644 --- a/batch_test.go +++ b/batch_test.go @@ -47,9 +47,9 @@ func TestBatch_Errors(t *testing.T) { t.Fatal(err) } - b := session.NewBatch(LoggedBatch) - b.Query("SELECT * FROM batch_errors WHERE id=2 AND val=?", nil) - if err := session.ExecuteBatch(b); err == nil { + b := session.Batch(LoggedBatch) + b = b.Query("SELECT * FROM gocql_test.batch_errors WHERE id=2 AND val=?", nil) + if err := b.Exec(); err == nil { t.Fatal("expected to get error for invalid query in batch") } } @@ -68,15 +68,17 @@ func TestBatch_WithTimestamp(t *testing.T) { micros := time.Now().UnixNano()/1e3 - 1000 - b := session.NewBatch(LoggedBatch) + b := session.Batch(LoggedBatch) b.WithTimestamp(micros) - b.Query("INSERT INTO batch_ts (id, val) VALUES (?, ?)", 1, "val") - if err := session.ExecuteBatch(b); err != nil { + b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 1, "val") + b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 2, "val") + + if err := b.Exec(); err != nil { t.Fatal(err) } var storedTs int64 - if err := session.Query(`SELECT writetime(val) FROM batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil { + if err := session.Query(`SELECT writetime(val) FROM gocql_test.batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil { t.Fatal(err) } diff --git a/cassandra_test.go b/cassandra_test.go index 3b0c61053..ec6969190 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -45,7 +45,7 @@ import ( "time" "unicode" - inf "gopkg.in/inf.v0" + "gopkg.in/inf.v0" ) func TestEmptyHosts(t *testing.T) { @@ -454,7 +454,7 @@ func TestCAS(t *testing.T) { t.Fatal("truncate:", err) } - successBatch := session.NewBatch(LoggedBatch) + successBatch := session.Batch(LoggedBatch) successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified) if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil { t.Fatal("insert:", err) @@ -462,7 +462,7 @@ func TestCAS(t *testing.T) { t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS) } - successBatch = session.NewBatch(LoggedBatch) + successBatch = session.Batch(LoggedBatch) successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title+"_foo", revid, modified) casMap := make(map[string]interface{}) if applied, _, err := session.MapExecuteBatchCAS(successBatch, casMap); err != nil { @@ -471,7 +471,7 @@ func TestCAS(t *testing.T) { t.Fatal("insert should have been applied") } - failBatch := session.NewBatch(LoggedBatch) + failBatch := session.Batch(LoggedBatch) failBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified) if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil { t.Fatal("insert:", err) @@ -479,14 +479,14 @@ func TestCAS(t *testing.T) { t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS) } - insertBatch := session.NewBatch(LoggedBatch) + insertBatch := session.Batch(LoggedBatch) insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))") insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))") if err := session.ExecuteBatch(insertBatch); err != nil { t.Fatal("insert:", err) } - failBatch = session.NewBatch(LoggedBatch) + failBatch = session.Batch(LoggedBatch) failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=2c3af400-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());") failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());") if applied, iter, err := session.ExecuteBatchCAS(failBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil { @@ -611,7 +611,7 @@ func TestBatch(t *testing.T) { t.Fatal("create table:", err) } - batch := session.NewBatch(LoggedBatch) + batch := session.Batch(LoggedBatch) for i := 0; i < 100; i++ { batch.Query(`INSERT INTO batch_table (id) VALUES (?)`, i) } @@ -643,9 +643,9 @@ func TestUnpreparedBatch(t *testing.T) { var batch *Batch if session.cfg.ProtoVersion == 2 { - batch = session.NewBatch(CounterBatch) + batch = session.Batch(CounterBatch) } else { - batch = session.NewBatch(UnloggedBatch) + batch = session.Batch(UnloggedBatch) } for i := 0; i < 100; i++ { @@ -684,7 +684,7 @@ func TestBatchLimit(t *testing.T) { t.Fatal("create table:", err) } - batch := session.NewBatch(LoggedBatch) + batch := session.Batch(LoggedBatch) for i := 0; i < 65537; i++ { batch.Query(`INSERT INTO batch_table2 (id) VALUES (?)`, i) } @@ -738,7 +738,7 @@ func TestTooManyQueryArgs(t *testing.T) { t.Fatal("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an error") } - batch := session.NewBatch(UnloggedBatch) + batch := session.Batch(UnloggedBatch) batch.Query("INSERT INTO too_many_query_args (id, value) VALUES (?, ?)", 1, 2, 3) err = session.ExecuteBatch(batch) @@ -770,7 +770,7 @@ func TestNotEnoughQueryArgs(t *testing.T) { t.Fatal("'`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1' should return an error") } - batch := session.NewBatch(UnloggedBatch) + batch := session.Batch(UnloggedBatch) batch.Query("INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)", 1, 2) err = session.ExecuteBatch(batch) @@ -1392,7 +1392,7 @@ func TestBatchQueryInfo(t *testing.T) { return values, nil } - batch := session.NewBatch(LoggedBatch) + batch := session.Batch(LoggedBatch) batch.Bind("INSERT INTO batch_query_info (id, cluster, value) VALUES (?, ?,?)", write) if err := session.ExecuteBatch(batch); err != nil { @@ -1520,7 +1520,7 @@ func TestPrepare_ReprepareBatch(t *testing.T) { } stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch") - batch := session.NewBatch(UnloggedBatch) + batch := session.Batch(UnloggedBatch) batch.Query(stmt, "bar") if err := conn.executeBatch(ctx, batch).Close(); err != nil { t.Fatalf("Failed to execute query for reprepare statement: %v", err) @@ -1904,7 +1904,7 @@ func TestBatchStats(t *testing.T) { t.Fatalf("failed to create table with error '%v'", err) } - b := session.NewBatch(LoggedBatch) + b := session.Batch(LoggedBatch) b.Query("INSERT INTO batchStats (id) VALUES (?)", 1) b.Query("INSERT INTO batchStats (id) VALUES (?)", 2) @@ -1947,7 +1947,7 @@ func TestBatchObserve(t *testing.T) { var observedBatch *observation - batch := session.NewBatch(LoggedBatch) + batch := session.Batch(LoggedBatch) batch.Observer(funcBatchObserver(func(ctx context.Context, o ObservedBatch) { if observedBatch != nil { t.Fatal("batch observe called more than once") @@ -3286,7 +3286,7 @@ func TestUnsetColBatch(t *testing.T) { t.Fatalf("failed to create table with error '%v'", err) } - b := session.NewBatch(LoggedBatch) + b := session.Batch(LoggedBatch) b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, 1, UnsetValue) b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, UnsetValue, "") b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 2, 2, UnsetValue) diff --git a/doc.go b/doc.go index 236b55e2f..109a85034 100644 --- a/doc.go +++ b/doc.go @@ -310,7 +310,7 @@ // # Batches // // The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql. -// Use Session.NewBatch to create a new batch and then fill-in details of individual queries. +// Use Session.Batch to create a new batch and then fill-in details of individual queries. // Then execute the batch with Session.ExecuteBatch. // // Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have diff --git a/example_batch_test.go b/example_batch_test.go index 2695e48bd..b27085ccc 100644 --- a/example_batch_test.go +++ b/example_batch_test.go @@ -29,7 +29,7 @@ import ( "fmt" "log" - gocql "github.com/gocql/gocql" + "github.com/gocql/gocql" ) // Example_batch demonstrates how to execute a batch of statements. @@ -49,7 +49,7 @@ func Example_batch() { ctx := context.Background() - b := session.NewBatch(gocql.UnloggedBatch).WithContext(ctx) + b := session.Batch(gocql.UnloggedBatch).WithContext(ctx) b.Entries = append(b.Entries, gocql.BatchEntry{ Stmt: "INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", Args: []interface{}{1, 2, "1.2"}, @@ -60,11 +60,19 @@ func Example_batch() { Args: []interface{}{1, 3, "1.3"}, Idempotent: true, }) + err = session.ExecuteBatch(b) if err != nil { log.Fatal(err) } + err = b.Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 4, "1.4"). + Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 5, "1.5"). + Exec() + if err != nil { + log.Fatal(err) + } + scanner := session.Query("SELECT pk, ck, description FROM example.batches").Iter().Scanner() for scanner.Next() { var pk, ck int32 @@ -77,4 +85,6 @@ func Example_batch() { } // 1 2 1.2 // 1 3 1.3 + // 1 4 1.4 + // 1 5 1.5 } diff --git a/example_lwt_batch_test.go b/example_lwt_batch_test.go index 916367eb3..c3cc8383d 100644 --- a/example_lwt_batch_test.go +++ b/example_lwt_batch_test.go @@ -29,7 +29,7 @@ import ( "fmt" "log" - gocql "github.com/gocql/gocql" + "github.com/gocql/gocql" ) // ExampleSession_MapExecuteBatchCAS demonstrates how to execute a batch lightweight transaction. @@ -62,7 +62,7 @@ func ExampleSession_MapExecuteBatchCAS() { } executeBatch := func(ck2Version int) { - b := session.NewBatch(gocql.LoggedBatch) + b := session.Batch(gocql.LoggedBatch) b.Entries = append(b.Entries, gocql.BatchEntry{ Stmt: "UPDATE my_lwt_batch_table SET value=? WHERE pk=? AND ck=? IF version=?", Args: []interface{}{"b", "pk1", "ck1", 1}, diff --git a/integration_test.go b/integration_test.go index 3622dfbd6..61ffbf504 100644 --- a/integration_test.go +++ b/integration_test.go @@ -218,7 +218,7 @@ func TestCustomPayloadMessages(t *testing.T) { iter.Close() // Batch Message - b := session.NewBatch(LoggedBatch) + b := session.Batch(LoggedBatch) b.CustomPayload = customPayload b.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)") if err := session.ExecuteBatch(b); err != nil { diff --git a/session.go b/session.go index b884735c2..d04a13672 100644 --- a/session.go +++ b/session.go @@ -731,6 +731,13 @@ func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter { return conn.executeBatch(ctx, b) } +// Exec executes a batch operation and returns nil if successful +// otherwise an error is returned describing the failure. +func (b *Batch) Exec() error { + iter := b.session.executeBatch(b) + return iter.Close() +} + func (s *Session) executeBatch(batch *Batch) *Iter { // fail fast if s.Closed() { @@ -1748,7 +1755,14 @@ type Batch struct { } // NewBatch creates a new batch operation using defaults defined in the cluster +// +// Deprecated: use session.Batch instead func (s *Session) NewBatch(typ BatchType) *Batch { + return s.Batch(typ) +} + +// Batch creates a new batch operation using defaults defined in the cluster +func (s *Session) Batch(typ BatchType) *Batch { s.mu.RLock() batch := &Batch{ Type: typ, @@ -1848,8 +1862,9 @@ func (b *Batch) SpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Batch } // Query adds the query to the batch operation -func (b *Batch) Query(stmt string, args ...interface{}) { +func (b *Batch) Query(stmt string, args ...interface{}) *Batch { b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args}) + return b } // Bind adds the query to the batch operation and correlates it with a binding callback diff --git a/session_test.go b/session_test.go index 0319a8a4c..850e88531 100644 --- a/session_test.go +++ b/session_test.go @@ -96,7 +96,7 @@ func TestSessionAPI(t *testing.T) { t.Fatalf("expected itr.err to be '%v', got '%v'", ErrNoConnections, itr.err) } - testBatch := s.NewBatch(LoggedBatch) + testBatch := s.Batch(LoggedBatch) testBatch.Query("test") err := s.ExecuteBatch(testBatch) @@ -219,7 +219,7 @@ func TestBatchBasicAPI(t *testing.T) { s.pool = cfg.PoolConfig.buildPool(s) // Test UnloggedBatch - b := s.NewBatch(UnloggedBatch) + b := s.Batch(UnloggedBatch) if b.Type != UnloggedBatch { t.Fatalf("expceted batch.Type to be '%v', got '%v'", UnloggedBatch, b.Type) } else if b.rt != cfg.RetryPolicy { @@ -227,7 +227,7 @@ func TestBatchBasicAPI(t *testing.T) { } // Test LoggedBatch - b = s.NewBatch(LoggedBatch) + b = s.Batch(LoggedBatch) if b.Type != LoggedBatch { t.Fatalf("expected batch.Type to be '%v', got '%v'", LoggedBatch, b.Type) }