Skip to content

Commit

Permalink
ScatterCon results observer tests
Browse files Browse the repository at this point in the history
Signed-off-by: Rafer Hazen <[email protected]>
  • Loading branch information
rafer committed Aug 22, 2024
1 parent 972f77c commit e6dc653
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions go/vt/vtgate/legacy_scatter_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,14 @@ func executeOnShardsReturnsErr(t *testing.T, ctx context.Context, res *srvtopo.R
return vterrors.Aggregate(errs)
}

type recordingResultsObserver struct {
recorded []*sqltypes.Result
}

func (o *recordingResultsObserver) observe(result *sqltypes.Result) {
o.recorded = append(o.recorded, result)
}

func TestMultiExecs(t *testing.T) {
ctx := utils.LeakCheckContext(t)
createSandbox("TestMultiExecs")
Expand Down Expand Up @@ -409,9 +417,17 @@ func TestMultiExecs(t *testing.T) {
},
},
}
results := []*sqltypes.Result{
{Info: "r0"},
{Info: "r1"},
}
sbc0.SetResults(results[0:1])
sbc1.SetResults(results[1:2])

observer := recordingResultsObserver{}

session := NewSafeSession(&vtgatepb.Session{})
_, err := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, false, nullResultsObserver{})
_, err := sc.ExecuteMultiShard(ctx, nil, rss, queries, session, false, false, &observer)
require.NoError(t, vterrors.Aggregate(err))
if len(sbc0.Queries) == 0 || len(sbc1.Queries) == 0 {
t.Fatalf("didn't get expected query")
Expand All @@ -428,8 +444,12 @@ func TestMultiExecs(t *testing.T) {
if !reflect.DeepEqual(sbc1.Queries[0].BindVariables, wantVars1) {
t.Errorf("got %+v, want %+v", sbc0.Queries[0].BindVariables, wantVars1)
}
assert.ElementsMatch(t, results, observer.recorded)

sbc0.Queries = nil
sbc1.Queries = nil
sbc0.SetResults(results[0:1])
sbc1.SetResults(results[1:2])

rss = []*srvtopo.ResolvedShard{
{
Expand All @@ -455,15 +475,18 @@ func TestMultiExecs(t *testing.T) {
"bv1": sqltypes.Int64BindVariable(1),
},
}

observer = recordingResultsObserver{}
_ = sc.StreamExecuteMulti(ctx, nil, "query", rss, bvs, session, false /* autocommit */, func(*sqltypes.Result) error {
return nil
}, nullResultsObserver{})
}, &observer)
if !reflect.DeepEqual(sbc0.Queries[0].BindVariables, wantVars0) {
t.Errorf("got %+v, want %+v", sbc0.Queries[0].BindVariables, wantVars0)
}
if !reflect.DeepEqual(sbc1.Queries[0].BindVariables, wantVars1) {
t.Errorf("got %+v, want %+v", sbc0.Queries[0].BindVariables, wantVars1)
}
assert.ElementsMatch(t, results, observer.recorded)
}

func TestScatterConnSingleDB(t *testing.T) {
Expand Down

0 comments on commit e6dc653

Please sign in to comment.