diff --git a/integration-tests/shutdown_test.go b/integration-tests/shutdown_test.go index 6b2db410..b6b44cda 100644 --- a/integration-tests/shutdown_test.go +++ b/integration-tests/shutdown_test.go @@ -67,6 +67,9 @@ func TestShutdownInFlightRequests(t *testing.T) { testSetup.Origin.Prime( simulacron.WhenQuery("SELECT * FROM test2", simulacron.NewWhenQueryOptions()). ThenSuccess().WithDelay(3 * time.Second)) + testSetup.Origin.Prime( + simulacron.WhenQuery("SELECT * FROM test3", simulacron.NewWhenQueryOptions()). + ThenSuccess().WithDelay(4 * time.Second)) queryMsg1 := &message.Query{ Query: "SELECT * FROM test1", @@ -78,21 +81,24 @@ func TestShutdownInFlightRequests(t *testing.T) { Options: nil, } + queryMsg3 := &message.Query{ + Query: "SELECT * FROM test3", + Options: nil, + } + beginTimestamp := time.Now() reqFrame := frame.NewFrame(primitive.ProtocolVersion4, 2, queryMsg1) inflightRequest, err := cqlConn.Send(reqFrame) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + require.Nil(t, err) reqFrame2 := frame.NewFrame(primitive.ProtocolVersion4, 3, queryMsg2) inflightRequest2, err := cqlConn.Send(reqFrame2) + require.Nil(t, err) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + reqFrame3 := frame.NewFrame(primitive.ProtocolVersion4, 4, queryMsg3) + inflightRequest3, err := cqlConn.Send(reqFrame3) + require.Nil(t, err) time.Sleep(1 * time.Second) @@ -119,15 +125,12 @@ func TestShutdownInFlightRequests(t *testing.T) { default: } - reqFrame3 := frame.NewFrame(primitive.ProtocolVersion4, 4, queryMsg1) - inflightRequest3, err := cqlConn.Send(reqFrame3) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + reqFrame4 := frame.NewFrame(primitive.ProtocolVersion4, 5, queryMsg1) + inflightRequest4, err := cqlConn.Send(reqFrame4) + require.Nil(t, err) select { - case rsp := <-inflightRequest3.Incoming(): + case rsp := <-inflightRequest4.Incoming(): require.Equal(t, primitive.OpCodeError, rsp.Header.OpCode) _, ok := rsp.Body.Message.(*message.Overloaded) require.True(t, ok) @@ -136,14 +139,24 @@ func TestShutdownInFlightRequests(t *testing.T) { } select { - case rsp := <-inflightRequest2.Incoming(): + case rsp, ok := <-inflightRequest2.Incoming(): + require.True(t, ok) require.Equal(t, primitive.OpCodeResult, rsp.Header.OpCode) case <-time.After(15 * time.Second): t.Fatalf("test timed out after 15 seconds") } - // 2 seconds instead of 3 just in case there is a time precision issue - require.GreaterOrEqual(t, time.Now().Sub(beginTimestamp).Nanoseconds(), (2 * time.Second).Nanoseconds()) + select { + case rsp, ok := <-inflightRequest3.Incoming(): + if ok { // ignore if last request's channel is closed before we read from it + require.Equal(t, primitive.OpCodeResult, rsp.Header.OpCode) + } + case <-time.After(15 * time.Second): + t.Fatalf("test timed out after 15 seconds") + } + + // 3 seconds instead of 4 just in case there is a time precision issue + require.GreaterOrEqual(t, time.Now().Sub(beginTimestamp).Nanoseconds(), (3 * time.Second).Nanoseconds()) select { case <-shutdownComplete: