diff --git a/Makefile b/Makefile index fc9660e..b95a4fb 100644 --- a/Makefile +++ b/Makefile @@ -35,7 +35,7 @@ endef $(eval $(call makemock, $$(DBMIGRATE_PATH), Driver, dbmigratemocks)) $(eval $(call makemock, pkg/httpserver, GoHTTPServer, httpservermocks)) $(eval $(call makemock, pkg/auth, Plugin, authmocks)) -$(eval $(call makemock, pkg/wsserver, WebSocketChannels, wsservermocks)) +$(eval $(call makemock, pkg/wsserver, Protocol, wsservermocks)) $(eval $(call makemock, pkg/wsserver, WebSocketServer, wsservermocks)) $(eval $(call makemock, pkg/dbsql, CRUD, crudmocks)) diff --git a/examples/ffpubsub.go b/examples/ffpubsub.go index 398617a..6ff8055 100644 --- a/examples/ffpubsub.go +++ b/examples/ffpubsub.go @@ -128,6 +128,8 @@ func setup(ctx context.Context) (pubSubESManager, *inMemoryStream, func()) { // Use SQLite in-memory DB conf := config.RootSection("ffpubsub") eventstreams.InitConfig(conf) + wsConf := conf.SubSection("ws") + wsserver.InitConfig(wsConf) dbConf := conf.SubSection("sqlite") dbsql.InitSQLiteConfig(dbConf) dbConf.Set(dbsql.SQLConfMigrationsAuto, true) @@ -139,7 +141,7 @@ func setup(ctx context.Context) (pubSubESManager, *inMemoryStream, func()) { sql, err := dbsql.NewSQLiteProvider(ctx, dbConf) assertNoError(err) - wsServer := wsserver.NewWebSocketServer(ctx) + wsServer := wsserver.NewWebSocketServer(ctx, wsserver.GenerateConfig(wsConf)) server := httptest.NewServer(http.HandlerFunc(wsServer.Handler)) u, err := url.Parse(server.URL) assertNoError(err) diff --git a/mocks/authmocks/plugin.go b/mocks/authmocks/plugin.go index 7fbce24..4ff0804 100644 --- a/mocks/authmocks/plugin.go +++ b/mocks/authmocks/plugin.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. +// Code generated by mockery v2.40.1. DO NOT EDIT. package authmocks diff --git a/mocks/crudmocks/crud.go b/mocks/crudmocks/crud.go index 23aa474..5961966 100644 --- a/mocks/crudmocks/crud.go +++ b/mocks/crudmocks/crud.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. +// Code generated by mockery v2.40.1. DO NOT EDIT. package crudmocks @@ -478,6 +478,24 @@ func (_m *CRUD[T]) Scoped(scope squirrel.Eq) dbsql.CRUD[T] { return r0 } +// TableAlias provides a mock function with given fields: +func (_m *CRUD[T]) TableAlias() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for TableAlias") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + // Update provides a mock function with given fields: ctx, id, update, hooks func (_m *CRUD[T]) Update(ctx context.Context, id string, update ffapi.Update, hooks ...dbsql.PostCompletionHook) error { _va := make([]interface{}, len(hooks)) diff --git a/mocks/dbmigratemocks/driver.go b/mocks/dbmigratemocks/driver.go index 18b589e..a61c374 100644 --- a/mocks/dbmigratemocks/driver.go +++ b/mocks/dbmigratemocks/driver.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. +// Code generated by mockery v2.40.1. DO NOT EDIT. package dbmigratemocks diff --git a/mocks/httpservermocks/go_http_server.go b/mocks/httpservermocks/go_http_server.go index 7ab9c91..f4a73df 100644 --- a/mocks/httpservermocks/go_http_server.go +++ b/mocks/httpservermocks/go_http_server.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. +// Code generated by mockery v2.40.1. DO NOT EDIT. package httpservermocks diff --git a/mocks/wsservermocks/protocol.go b/mocks/wsservermocks/protocol.go new file mode 100644 index 0000000..10863ae --- /dev/null +++ b/mocks/wsservermocks/protocol.go @@ -0,0 +1,64 @@ +// Code generated by mockery v2.40.1. DO NOT EDIT. + +package wsservermocks + +import ( + context "context" + + wsserver "github.com/hyperledger/firefly-common/pkg/wsserver" + mock "github.com/stretchr/testify/mock" +) + +// Protocol is an autogenerated mock type for the Protocol type +type Protocol struct { + mock.Mock +} + +// Broadcast provides a mock function with given fields: ctx, stream, payload +func (_m *Protocol) Broadcast(ctx context.Context, stream string, payload interface{}) { + _m.Called(ctx, stream, payload) +} + +// RoundTrip provides a mock function with given fields: ctx, stream, payload +func (_m *Protocol) RoundTrip(ctx context.Context, stream string, payload wsserver.WSBatch) (*wsserver.WebSocketCommandMessage, error) { + ret := _m.Called(ctx, stream, payload) + + if len(ret) == 0 { + panic("no return value specified for RoundTrip") + } + + var r0 *wsserver.WebSocketCommandMessage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, wsserver.WSBatch) (*wsserver.WebSocketCommandMessage, error)); ok { + return rf(ctx, stream, payload) + } + if rf, ok := ret.Get(0).(func(context.Context, string, wsserver.WSBatch) *wsserver.WebSocketCommandMessage); ok { + r0 = rf(ctx, stream, payload) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*wsserver.WebSocketCommandMessage) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, wsserver.WSBatch) error); ok { + r1 = rf(ctx, stream, payload) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewProtocol creates a new instance of Protocol. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewProtocol(t interface { + mock.TestingT + Cleanup(func()) +}) *Protocol { + mock := &Protocol{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/mocks/wsservermocks/web_socket_channels.go b/mocks/wsservermocks/web_socket_channels.go deleted file mode 100644 index 1123e65..0000000 --- a/mocks/wsservermocks/web_socket_channels.go +++ /dev/null @@ -1,68 +0,0 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. - -package wsservermocks - -import ( - wsserver "github.com/hyperledger/firefly-common/pkg/wsserver" - mock "github.com/stretchr/testify/mock" -) - -// WebSocketChannels is an autogenerated mock type for the WebSocketChannels type -type WebSocketChannels struct { - mock.Mock -} - -// GetChannels provides a mock function with given fields: streamName -func (_m *WebSocketChannels) GetChannels(streamName string) (chan<- interface{}, chan<- interface{}, <-chan *wsserver.WebSocketCommandMessageOrError) { - ret := _m.Called(streamName) - - if len(ret) == 0 { - panic("no return value specified for GetChannels") - } - - var r0 chan<- interface{} - var r1 chan<- interface{} - var r2 <-chan *wsserver.WebSocketCommandMessageOrError - if rf, ok := ret.Get(0).(func(string) (chan<- interface{}, chan<- interface{}, <-chan *wsserver.WebSocketCommandMessageOrError)); ok { - return rf(streamName) - } - if rf, ok := ret.Get(0).(func(string) chan<- interface{}); ok { - r0 = rf(streamName) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(chan<- interface{}) - } - } - - if rf, ok := ret.Get(1).(func(string) chan<- interface{}); ok { - r1 = rf(streamName) - } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).(chan<- interface{}) - } - } - - if rf, ok := ret.Get(2).(func(string) <-chan *wsserver.WebSocketCommandMessageOrError); ok { - r2 = rf(streamName) - } else { - if ret.Get(2) != nil { - r2 = ret.Get(2).(<-chan *wsserver.WebSocketCommandMessageOrError) - } - } - - return r0, r1, r2 -} - -// NewWebSocketChannels creates a new instance of WebSocketChannels. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewWebSocketChannels(t interface { - mock.TestingT - Cleanup(func()) -}) *WebSocketChannels { - mock := &WebSocketChannels{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/mocks/wsservermocks/web_socket_server.go b/mocks/wsservermocks/web_socket_server.go index cfcaece..45c6562 100644 --- a/mocks/wsservermocks/web_socket_server.go +++ b/mocks/wsservermocks/web_socket_server.go @@ -1,12 +1,14 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. +// Code generated by mockery v2.40.1. DO NOT EDIT. package wsservermocks import ( + context "context" http "net/http" - wsserver "github.com/hyperledger/firefly-common/pkg/wsserver" mock "github.com/stretchr/testify/mock" + + wsserver "github.com/hyperledger/firefly-common/pkg/wsserver" ) // WebSocketServer is an autogenerated mock type for the WebSocketServer type @@ -14,55 +16,49 @@ type WebSocketServer struct { mock.Mock } +// Broadcast provides a mock function with given fields: ctx, stream, payload +func (_m *WebSocketServer) Broadcast(ctx context.Context, stream string, payload interface{}) { + _m.Called(ctx, stream, payload) +} + // Close provides a mock function with given fields: func (_m *WebSocketServer) Close() { _m.Called() } -// GetChannels provides a mock function with given fields: streamName -func (_m *WebSocketServer) GetChannels(streamName string) (chan<- interface{}, chan<- interface{}, <-chan *wsserver.WebSocketCommandMessageOrError) { - ret := _m.Called(streamName) +// Handler provides a mock function with given fields: w, r +func (_m *WebSocketServer) Handler(w http.ResponseWriter, r *http.Request) { + _m.Called(w, r) +} + +// RoundTrip provides a mock function with given fields: ctx, stream, payload +func (_m *WebSocketServer) RoundTrip(ctx context.Context, stream string, payload wsserver.WSBatch) (*wsserver.WebSocketCommandMessage, error) { + ret := _m.Called(ctx, stream, payload) if len(ret) == 0 { - panic("no return value specified for GetChannels") + panic("no return value specified for RoundTrip") } - var r0 chan<- interface{} - var r1 chan<- interface{} - var r2 <-chan *wsserver.WebSocketCommandMessageOrError - if rf, ok := ret.Get(0).(func(string) (chan<- interface{}, chan<- interface{}, <-chan *wsserver.WebSocketCommandMessageOrError)); ok { - return rf(streamName) + var r0 *wsserver.WebSocketCommandMessage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, wsserver.WSBatch) (*wsserver.WebSocketCommandMessage, error)); ok { + return rf(ctx, stream, payload) } - if rf, ok := ret.Get(0).(func(string) chan<- interface{}); ok { - r0 = rf(streamName) + if rf, ok := ret.Get(0).(func(context.Context, string, wsserver.WSBatch) *wsserver.WebSocketCommandMessage); ok { + r0 = rf(ctx, stream, payload) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(chan<- interface{}) + r0 = ret.Get(0).(*wsserver.WebSocketCommandMessage) } } - if rf, ok := ret.Get(1).(func(string) chan<- interface{}); ok { - r1 = rf(streamName) + if rf, ok := ret.Get(1).(func(context.Context, string, wsserver.WSBatch) error); ok { + r1 = rf(ctx, stream, payload) } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).(chan<- interface{}) - } - } - - if rf, ok := ret.Get(2).(func(string) <-chan *wsserver.WebSocketCommandMessageOrError); ok { - r2 = rf(streamName) - } else { - if ret.Get(2) != nil { - r2 = ret.Get(2).(<-chan *wsserver.WebSocketCommandMessageOrError) - } + r1 = ret.Error(1) } - return r0, r1, r2 -} - -// Handler provides a mock function with given fields: w, r -func (_m *WebSocketServer) Handler(w http.ResponseWriter, r *http.Request) { - _m.Called(w, r) + return r0, r1 } // NewWebSocketServer creates a new instance of WebSocketServer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. diff --git a/pkg/dbsql/crud.go b/pkg/dbsql/crud.go index 94c7bcc..2c951d3 100644 --- a/pkg/dbsql/crud.go +++ b/pkg/dbsql/crud.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -117,6 +117,7 @@ type CRUD[T Resource] interface { NewFilterBuilder(ctx context.Context) ffapi.FilterBuilder NewUpdateBuilder(ctx context.Context) ffapi.UpdateBuilder GetQueryFactory() ffapi.QueryFactory + TableAlias() string Scoped(scope sq.Eq) CRUD[T] // allows dynamic scoping to a collection } @@ -151,6 +152,13 @@ func (c *CrudBase[T]) Scoped(scope sq.Eq) CRUD[T] { return &cScoped } +func (c *CrudBase[T]) TableAlias() string { + if c.ReadTableAlias != "" { + return c.ReadTableAlias + } + return c.Table +} + func (c *CrudBase[T]) GetQueryFactory() ffapi.QueryFactory { return c.QueryFactory } diff --git a/pkg/dbsql/crud_test.go b/pkg/dbsql/crud_test.go index 8ef8662..a07998f 100644 --- a/pkg/dbsql/crud_test.go +++ b/pkg/dbsql/crud_test.go @@ -656,7 +656,9 @@ func TestLeftJOINExample(t *testing.T) { ctx := context.Background() crudables := newCRUDCollection(sql.db, "ns1") + assert.Equal(t, "crudables", crudables.TableAlias()) linkables := newLinkableCollection(sql.db, "ns1") + assert.Equal(t, "l", linkables.TableAlias()) c1 := &TestCRUDable{ ResourceBase: ResourceBase{ diff --git a/pkg/eventstreams/activestream.go b/pkg/eventstreams/activestream.go index dabfc7c..3bedad3 100644 --- a/pkg/eventstreams/activestream.go +++ b/pkg/eventstreams/activestream.go @@ -28,7 +28,6 @@ import ( ) type eventStreamBatch[DataType any] struct { - number int64 events []*Event[DataType] batchTimer *time.Timer } @@ -37,7 +36,6 @@ type activeStream[CT EventStreamSpec, DT any] struct { *eventStream[CT, DT] ctx context.Context cancelCtx func() - batchNumber int64 filterSkipped int64 EventStreamStatistics eventLoopDone chan struct{} @@ -171,9 +169,7 @@ func (as *activeStream[CT, DT]) runBatchLoop() { as.filterSkipped++ } else { if batch == nil { - as.batchNumber++ batch = &eventStreamBatch[DT]{ - number: as.batchNumber, batchTimer: time.NewTimer(batchTimeout), events: make([]*Event[DT], 0, *esSpec.BatchSize), } @@ -263,7 +259,6 @@ func (as *activeStream[CT, DT]) checkpointRoutine() { // performActionWithRetry performs an action, with exponential back-off retry up // to a given threshold. Only returns error in the case that the context is closed. func (as *activeStream[CT, DT]) dispatchBatch(batch *eventStreamBatch[DT]) (err error) { - as.LastDispatchNumber = batch.number as.LastDispatchTime = fftypes.Now() as.LastDispatchFailure = "" as.LastDispatchAttempts = 0 @@ -273,17 +268,13 @@ func (as *activeStream[CT, DT]) dispatchBatch(batch *eventStreamBatch[DT]) (err for { // Short exponential back-off retry err := as.retry.Do(as.ctx, "action", func(_ int) (retry bool, err error) { - log.L(as.ctx).Debugf("Batch %d attempt %d dispatching. Len=%d", - batch.number, as.LastDispatchAttempts, len(batch.events)) + log.L(as.ctx).Debugf("Batch attempt %d dispatching. Len=%d", as.LastDispatchAttempts, len(batch.events)) err = as.action.AttemptDispatch(as.ctx, as.LastDispatchAttempts, &EventBatch[DT]{ - Type: MessageTypeEventBatch, - StreamID: as.spec.GetID(), - BatchNumber: batch.number, - Events: batch.events, + Type: MessageTypeEventBatch, + Events: batch.events, }) if err != nil { - log.L(as.ctx).Errorf("Batch %d attempt %d failed. err=%s", - batch.number, as.LastDispatchAttempts, err) + log.L(as.ctx).Errorf("Batch attempt %d failed. err=%s", as.LastDispatchAttempts, err) as.LastDispatchAttempts++ as.LastDispatchFailure = err.Error() as.LastDispatchStatus = DispatchStatusRetrying diff --git a/pkg/eventstreams/e2e_test.go b/pkg/eventstreams/e2e_test.go index c0601d1..cef28d3 100644 --- a/pkg/eventstreams/e2e_test.go +++ b/pkg/eventstreams/e2e_test.go @@ -146,6 +146,58 @@ func TestE2E_DeliveryWebSockets(t *testing.T) { assert.Equal(t, 1, ts.startCount) } +func TestE2E_DeliveryWebSocketsBroadcast(t *testing.T) { + ctx, p, wss, wsc, done := setupE2ETest(t) + + ts := &testSource{started: make(chan struct{})} + close(ts.started) // start delivery immediately - will block as no WS connected + + mgr, err := NewEventStreamManager[*GenericEventStream, testData](ctx, GenerateConfig[*GenericEventStream, testData](ctx), p, wss, ts) + assert.NoError(t, err) + + // Create a stream to sub-select one topic + es1 := &GenericEventStream{ + Type: &EventStreamTypeWebSocket, + EventStreamSpecFields: EventStreamSpecFields{ + TopicFilter: ptrTo("topic_1"), // only one of the topics + BatchSize: ptrTo(10), + }, + WebSocket: &WebSocketConfig{ + DistributionMode: &DistributionModeBroadcast, + }, + } + created, err := mgr.UpsertStream(ctx, "stream1", es1) + assert.NoError(t, err) + assert.True(t, created) + + // Connect our websocket and start it + err = wsc.Connect() + assert.NoError(t, err) + err = wsc.Send(ctx, []byte(`{"type":"start","stream":"stream1"}`)) + assert.NoError(t, err) + + expectedNumber := 1 + for i := 0; i < 10; i++ { + data := <-wsc.Receive() + var batch EventBatch[testData] + err := json.Unmarshal(data, &batch) + assert.NoError(t, err) + // each batch should be 10 + assert.Len(t, batch.Events, 10) + for _, e := range batch.Events { + assert.Equal(t, "topic_1", e.Topic) + assert.Equal(t, expectedNumber, e.Data.Field1) + // messages are published 0,1,2 over 10 topics, and we're only getting one of those topics + expectedNumber += 10 + } + } + + // Check we ran the loop just once, and from the empty string for the checkpoint (as there was no InitialSequenceID) + done() + assert.Equal(t, "", ts.sequenceStartedWith) + assert.Equal(t, 1, ts.startCount) +} + func TestE2E_DeliveryWebSocketsNack(t *testing.T) { ctx, p, wss, wsc, done := setupE2ETest(t, func() { RetrySection.Set(retry.ConfigMaximumDelay, "1ms" /* spin quickly */) @@ -254,6 +306,54 @@ func TestE2E_WebsocketDeliveryRestartReset(t *testing.T) { } +func TestE2E_ResetStreamWhileAwaitingAck(t *testing.T) { + ctx, p, wss, wsc, done := setupE2ETest(t) + + ts := &testSource{started: make(chan struct{})} + close(ts.started) // start delivery immediately - will block as no WS connected + + mgr, err := NewEventStreamManager[*GenericEventStream, testData](ctx, GenerateConfig[*GenericEventStream, testData](ctx), p, wss, ts) + assert.NoError(t, err) + + // Create a stream to sub-select one topic + es1 := &GenericEventStream{ + Type: &EventStreamTypeWebSocket, + EventStreamSpecFields: EventStreamSpecFields{ + TopicFilter: ptrTo("topic_1"), // only one of the topics + BatchSize: ptrTo(10), + }, + } + created, err := mgr.UpsertStream(ctx, "stream1", es1) + assert.NoError(t, err) + assert.True(t, created) + + // Connect our websocket and start it + err = wsc.Connect() + assert.NoError(t, err) + err = wsc.Send(ctx, []byte(`{"type":"start","stream":"stream1"}`)) + assert.NoError(t, err) + + // Receive the message batch + data := <-wsc.Receive() + var batch EventBatch[testData] + err = json.Unmarshal(data, &batch) + assert.NoError(t, err) + + // Do a reset before we ack. + err = mgr.ResetStream(ctx, "stream1", ptrTo("12345")) + assert.NoError(t, err) + + // Should get the batch again + data = <-wsc.Receive() + err = json.Unmarshal(data, &batch) + assert.NoError(t, err) + + // Check we did the reset + done() + assert.Equal(t, "12345", ts.sequenceStartedWith) + assert.Equal(t, 2, ts.startCount) +} + func TestE2E_DeliveryWebHooks200(t *testing.T) { ctx, p, wss, wsc, done := setupE2ETest(t) defer done() @@ -499,7 +599,7 @@ func wsReceiveNack(ctx context.Context, t *testing.T, wsc wsclient.WSClient, cb assert.NoError(t, err) } -func setupE2ETest(t *testing.T, extraSetup ...func()) (context.Context, Persistence[*GenericEventStream], wsserver.WebSocketChannels, wsclient.WSClient, func()) { +func setupE2ETest(t *testing.T, extraSetup ...func()) (context.Context, Persistence[*GenericEventStream], wsserver.Protocol, wsclient.WSClient, func()) { logrus.SetLevel(logrus.TraceLevel) ctx := context.Background() @@ -507,6 +607,8 @@ func setupE2ETest(t *testing.T, extraSetup ...func()) (context.Context, Persiste conf := config.RootSection("ut") dbConf := conf.SubSection("db") esConf := conf.SubSection("eventstreams") + wsServerConf := conf.SubSection("wss") + wsserver.InitConfig(wsServerConf) dbsql.InitSQLiteConfig(dbConf) InitConfig(esConf) @@ -529,7 +631,7 @@ func setupE2ETest(t *testing.T, extraSetup ...func()) (context.Context, Persiste p.EventStreams().Validate() p.Checkpoints().Validate() - wss := wsserver.NewWebSocketServer(ctx) + wss := wsserver.NewWebSocketServer(ctx, wsserver.GenerateConfig(wsServerConf)) server := httptest.NewServer(http.HandlerFunc(wss.Handler)) // Build the WS connection, but don't connect it yet diff --git a/pkg/eventstreams/event.go b/pkg/eventstreams/event.go index d160e0b..f2cb811 100644 --- a/pkg/eventstreams/event.go +++ b/pkg/eventstreams/event.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,15 +18,20 @@ package eventstreams import ( "encoding/json" + + "github.com/hyperledger/firefly-common/pkg/wsserver" ) const MessageTypeEventBatch = "event_batch" type EventBatch[DataType any] struct { - Type string `json:"type"` // always MessageTypeEventBatch (for consistent WebSocket flow control) - StreamID string `json:"stream"` // the ID of the event stream for this event - BatchNumber int64 `json:"batchNumber"` // should be provided back in the ack - Events []*Event[DataType] `json:"events"` // an array of events allows efficient batch acknowledgment + wsserver.BatchHeader + Type string `json:"type"` // always MessageTypeEventBatch (for consistent WebSocket flow control) + Events []*Event[DataType] `json:"events"` // an array of events allows efficient batch acknowledgment +} + +func (eb *EventBatch[DataType]) GetBatchHeader() *wsserver.BatchHeader { + return &eb.BatchHeader } type Event[DataType any] struct { diff --git a/pkg/eventstreams/eventstreams.go b/pkg/eventstreams/eventstreams.go index 105731d..443d627 100644 --- a/pkg/eventstreams/eventstreams.go +++ b/pkg/eventstreams/eventstreams.go @@ -106,7 +106,6 @@ type EventStreamSpecFields struct { type EventStreamStatistics struct { StartTime *fftypes.FFTime `ffstruct:"EventStreamStatistics" json:"startTime"` LastDispatchTime *fftypes.FFTime `ffstruct:"EventStreamStatistics" json:"lastDispatchTime"` - LastDispatchNumber int64 `ffstruct:"EventStreamStatistics" json:"lastDispatchBatch"` LastDispatchAttempts int `ffstruct:"EventStreamStatistics" json:"lastDispatchAttempts,omitempty"` LastDispatchFailure string `ffstruct:"EventStreamStatistics" json:"lastDispatchFailure,omitempty"` LastDispatchStatus DispatchStatus `ffstruct:"EventStreamStatistics" json:"lastDispatchComplete"` diff --git a/pkg/eventstreams/manager.go b/pkg/eventstreams/manager.go index 350a33c..96cc2d5 100644 --- a/pkg/eventstreams/manager.go +++ b/pkg/eventstreams/manager.go @@ -80,13 +80,13 @@ type esManager[CT EventStreamSpec, DT any] struct { mux sync.Mutex streams map[string]*eventStream[CT, DT] tlsConfigs map[string]*tls.Config - wsChannels wsserver.WebSocketChannels + wsProtocol wsserver.Protocol persistence Persistence[CT] runtime Runtime[CT, DT] dispatchers map[EventStreamType]DispatcherFactory[CT, DT] } -func NewEventStreamManager[CT EventStreamSpec, DT any](ctx context.Context, config *Config[CT, DT], p Persistence[CT], wsChannels wsserver.WebSocketChannels, source Runtime[CT, DT]) (es Manager[CT], err error) { +func NewEventStreamManager[CT EventStreamSpec, DT any](ctx context.Context, config *Config[CT, DT], p Persistence[CT], wsProtocol wsserver.Protocol, source Runtime[CT, DT]) (es Manager[CT], err error) { if config.Retry == nil { return nil, i18n.NewError(ctx, i18n.MsgESConfigNotInitialized) @@ -106,7 +106,7 @@ func NewEventStreamManager[CT EventStreamSpec, DT any](ctx context.Context, conf tlsConfigs: tlsConfigs, runtime: source, persistence: p, - wsChannels: wsChannels, + wsProtocol: wsProtocol, streams: map[string]*eventStream[CT, DT]{}, dispatchers: config.AdditionalDispatchers, } diff --git a/pkg/eventstreams/webhooks_test.go b/pkg/eventstreams/webhooks_test.go index 8616217..f070c71 100644 --- a/pkg/eventstreams/webhooks_test.go +++ b/pkg/eventstreams/webhooks_test.go @@ -28,7 +28,7 @@ import ( "github.com/hyperledger/firefly-common/pkg/ffapi" "github.com/hyperledger/firefly-common/pkg/ffresty" "github.com/hyperledger/firefly-common/pkg/fftls" - "github.com/hyperledger/firefly-common/pkg/fftypes" + "github.com/hyperledger/firefly-common/pkg/wsserver" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -82,8 +82,10 @@ func TestWebhooksBadHost(t *testing.T) { wh := newTestWebhooks(t, &WebhookConfig{URL: &u}) err := wh.AttemptDispatch(context.Background(), 0, &EventBatch[testData]{ - StreamID: fftypes.NewUUID().String(), - BatchNumber: 1, + BatchHeader: wsserver.BatchHeader{ + BatchNumber: 1, + Stream: "stream1", + }, Events: []*Event[testData]{ {Data: &testData{Field1: 12345}}, }, @@ -98,8 +100,10 @@ func TestWebhooksPrivateBlocked(t *testing.T) { }) err := wh.AttemptDispatch(context.Background(), 0, &EventBatch[testData]{ - StreamID: fftypes.NewUUID().String(), - BatchNumber: 1, + BatchHeader: wsserver.BatchHeader{ + BatchNumber: 1, + Stream: "stream1", + }, Events: []*Event[testData]{ {Data: &testData{Field1: 12345}}, }, @@ -132,8 +136,10 @@ func TestWebhooksCustomHeaders403(t *testing.T) { done := make(chan struct{}) go func() { err := wh.AttemptDispatch(context.Background(), 0, &EventBatch[testData]{ - StreamID: fftypes.NewUUID().String(), - BatchNumber: 1, + BatchHeader: wsserver.BatchHeader{ + BatchNumber: 1, + Stream: "stream1", + }, Events: []*Event[testData]{ {Data: &testData{Field1: 12345}}, }, @@ -155,8 +161,10 @@ func TestWebhooksCustomHeadersConnectFail(t *testing.T) { done := make(chan struct{}) go func() { err := wh.AttemptDispatch(context.Background(), 0, &EventBatch[testData]{ - StreamID: fftypes.NewUUID().String(), - BatchNumber: 1, + BatchHeader: wsserver.BatchHeader{ + BatchNumber: 1, + Stream: "stream1", + }, Events: []*Event[testData]{ {Data: &testData{Field1: 12345}}, }, @@ -191,8 +199,10 @@ func TestWebhooksTLS(t *testing.T) { done := make(chan struct{}) go func() { err := wh.AttemptDispatch(context.Background(), 0, &EventBatch[testData]{ - StreamID: fftypes.NewUUID().String(), - BatchNumber: 1, + BatchHeader: wsserver.BatchHeader{ + BatchNumber: 1, + Stream: "stream1", + }, Events: []*Event[testData]{ {Data: &testData{Field1: 12345}}, }, diff --git a/pkg/eventstreams/websockets.go b/pkg/eventstreams/websockets.go index 8b00cc5..e4ba31c 100644 --- a/pkg/eventstreams/websockets.go +++ b/pkg/eventstreams/websockets.go @@ -22,7 +22,6 @@ import ( "database/sql/driver" "github.com/hyperledger/firefly-common/pkg/fftypes" - "github.com/hyperledger/firefly-common/pkg/i18n" "github.com/hyperledger/firefly-common/pkg/log" "github.com/hyperledger/firefly-common/pkg/wsserver" ) @@ -64,13 +63,13 @@ func (wsf *webSocketDispatcherFactory[CT, DT]) Validate(ctx context.Context, con type webSocketAction[DT any] struct { topic string spec *WebSocketConfig - wsChannels wsserver.WebSocketChannels + wsProtocol wsserver.Protocol } func (wsf *webSocketDispatcherFactory[CT, DT]) NewDispatcher(_ context.Context, _ *Config[CT, DT], spec CT) Dispatcher[DT] { return &webSocketAction[DT]{ spec: spec.WebSocketConf(), - wsChannels: wsf.esm.wsChannels, + wsProtocol: wsf.esm.wsProtocol, topic: *spec.ESFields().Name, } } @@ -78,30 +77,14 @@ func (wsf *webSocketDispatcherFactory[CT, DT]) NewDispatcher(_ context.Context, func (w *webSocketAction[DT]) AttemptDispatch(ctx context.Context, attempt int, batch *EventBatch[DT]) error { var err error - // Get a blocking channel to send and receive on our chosen namespace - sender, broadcaster, receiver := w.wsChannels.GetChannels(w.topic) - - var channel chan<- interface{} isBroadcast := *w.spec.DistributionMode == DistributionModeBroadcast - if isBroadcast { - channel = broadcaster - } else { - channel = sender - } - - // Send the batch of events - select { - case channel <- batch: - break - case <-ctx.Done(): - err = i18n.NewError(ctx, i18n.MsgWebSocketInterruptedSend) - } - if err == nil && !isBroadcast { - log.L(ctx).Infof("Batch %d dispatched (len=%d,attempt=%d)", batch.BatchNumber, len(batch.Events), attempt) - err = w.waitForAck(ctx, receiver, batch.BatchNumber) + if isBroadcast { + w.wsProtocol.Broadcast(ctx, w.topic, batch) + return nil } + _, err = w.wsProtocol.RoundTrip(ctx, w.topic, batch) // Pass back any exception due if err != nil { log.L(ctx).Infof("WebSocket event batch %d delivery failed (len=%d,attempt=%d): %s", batch.BatchNumber, len(batch.Events), attempt, err) @@ -110,24 +93,3 @@ func (w *webSocketAction[DT]) AttemptDispatch(ctx context.Context, attempt int, log.L(ctx).Infof("WebSocket event batch %d complete (len=%d,attempt=%d)", batch.BatchNumber, len(batch.Events), attempt) return nil } - -func (w *webSocketAction[DT]) waitForAck(ctx context.Context, receiver <-chan *wsserver.WebSocketCommandMessageOrError, batchNumber int64) error { - // Wait for the next ack or exception - for { - select { - case msgOrErr := <-receiver: - if msgOrErr.Err != nil { - // If we get an error, we have to assume the other side did not receive this batch, and send it again - return msgOrErr.Err - } - if msgOrErr.Msg.BatchNumber != batchNumber { - log.L(ctx).Infof("Discarding ack for batch %d (awaiting %d)", msgOrErr.Msg.BatchNumber, batchNumber) - continue - } - log.L(ctx).Infof("Batch %d acknowledged", batchNumber) - return nil - case <-ctx.Done(): - return i18n.NewError(ctx, i18n.MsgWebSocketInterruptedReceive) - } - } -} diff --git a/pkg/eventstreams/websockets_test.go b/pkg/eventstreams/websockets_test.go index f03cd02..0ea2c2c 100644 --- a/pkg/eventstreams/websockets_test.go +++ b/pkg/eventstreams/websockets_test.go @@ -15,151 +15,3 @@ // limitations under the License. package eventstreams - -import ( - "context" - "fmt" - "testing" - - "github.com/hyperledger/firefly-common/mocks/wsservermocks" - "github.com/hyperledger/firefly-common/pkg/ffapi" - "github.com/hyperledger/firefly-common/pkg/fftypes" - "github.com/hyperledger/firefly-common/pkg/wsserver" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func mockWSChannels(wsc *wsservermocks.WebSocketChannels) (chan interface{}, chan interface{}, chan *wsserver.WebSocketCommandMessageOrError) { - senderChannel := make(chan interface{}, 1) - broadcastChannel := make(chan interface{}, 1) - receiverChannel := make(chan *wsserver.WebSocketCommandMessageOrError, 1) - wsc.On("GetChannels", "ut_stream").Return((chan<- interface{})(senderChannel), (chan<- interface{})(broadcastChannel), (<-chan *wsserver.WebSocketCommandMessageOrError)(receiverChannel)).Maybe() - return senderChannel, broadcastChannel, receiverChannel -} - -func newTestWebSocketsFactory(t *testing.T) (context.Context, *esManager[*GenericEventStream, testData], *wsservermocks.WebSocketChannels, *webSocketDispatcherFactory[*GenericEventStream, testData]) { - ctx, mgr, _, done := newMockESManager(t, func(mdb *mockPersistence) { - mdb.eventStreams.On("GetMany", mock.Anything, mock.Anything).Return([]*GenericEventStream{}, &ffapi.FilterResult{}, nil) - }) - done() - - mws := wsservermocks.NewWebSocketChannels(t) - mgr.wsChannels = mws - - return ctx, mgr, mws, &webSocketDispatcherFactory[*GenericEventStream, testData]{esm: mgr} -} - -func TestWSAttemptIgnoreWrongAcks(t *testing.T) { - - ctx, mgr, mws, whf := newTestWebSocketsFactory(t) - _, _, rc := mockWSChannels(mws) - - go func() { - rc <- &wsserver.WebSocketCommandMessageOrError{Msg: &wsserver.WebSocketCommandMessage{ - BatchNumber: 12345, - }} - rc <- &wsserver.WebSocketCommandMessageOrError{Msg: &wsserver.WebSocketCommandMessage{ - BatchNumber: 23456, - }} - }() - - dmw := DistributionModeBroadcast - spec := &GenericEventStream{ - EventStreamSpecFields: EventStreamSpecFields{ - Name: ptrTo("ut_stream"), - }, - WebSocket: &WebSocketConfig{ - DistributionMode: &dmw, - }, - } - wsa := whf.NewDispatcher(ctx, &mgr.config, spec).(*webSocketAction[testData]) - - err := wsa.AttemptDispatch(context.Background(), 0, &EventBatch[testData]{ - StreamID: fftypes.NewUUID().String(), - BatchNumber: 1, - Events: []*Event[testData]{ - {Data: &testData{Field1: 12345}}, - }, - }) - assert.NoError(t, err) - - err = wsa.waitForAck(context.Background(), rc, 23456) - assert.NoError(t, err) -} - -func TestWSattemptDispatchExitPushingEvent(t *testing.T) { - - ctx, mgr, mws, whf := newTestWebSocketsFactory(t) - _, bc, _ := mockWSChannels(mws) - bc <- []*fftypes.JSONAny{} // block the broadcast channel - - dmw := DistributionModeBroadcast - spec := &GenericEventStream{ - EventStreamSpecFields: EventStreamSpecFields{ - Name: ptrTo("ut_stream"), - }, - WebSocket: &WebSocketConfig{ - DistributionMode: &dmw, - }, - } - wsa := whf.NewDispatcher(ctx, &mgr.config, spec).(*webSocketAction[testData]) - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - err := wsa.AttemptDispatch(ctx, 0, &EventBatch[testData]{ - StreamID: fftypes.NewUUID().String(), - BatchNumber: 1, - Events: []*Event[testData]{ - {Data: &testData{Field1: 12345}}, - }, - }) - assert.Regexp(t, "FF00225", err) - -} - -func TestWSattemptDispatchExitReceivingReply(t *testing.T) { - - ctx, mgr, mws, whf := newTestWebSocketsFactory(t) - _, _, rc := mockWSChannels(mws) - - dmw := DistributionModeBroadcast - spec := &GenericEventStream{ - EventStreamSpecFields: EventStreamSpecFields{ - Name: ptrTo("ut_stream"), - }, - WebSocket: &WebSocketConfig{ - DistributionMode: &dmw, - }, - } - wsa := whf.NewDispatcher(ctx, &mgr.config, spec).(*webSocketAction[testData]) - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - err := wsa.waitForAck(ctx, rc, -1) - assert.Regexp(t, "FF00226", err) - -} - -func TestWSattemptDispatchNackFromClient(t *testing.T) { - - ctx, mgr, mws, whf := newTestWebSocketsFactory(t) - _, _, rc := mockWSChannels(mws) - rc <- &wsserver.WebSocketCommandMessageOrError{ - Err: fmt.Errorf("pop"), - } - - dmw := DistributionModeBroadcast - spec := &GenericEventStream{ - EventStreamSpecFields: EventStreamSpecFields{ - Name: ptrTo("ut_stream"), - }, - WebSocket: &WebSocketConfig{ - DistributionMode: &dmw, - }, - } - wsa := whf.NewDispatcher(ctx, &mgr.config, spec).(*webSocketAction[testData]) - - err := wsa.waitForAck(context.Background(), rc, -1) - assert.Regexp(t, "pop", err) - -} diff --git a/pkg/ffapi/apiserver.go b/pkg/ffapi/apiserver.go index 2d43f02..d1d639a 100644 --- a/pkg/ffapi/apiserver.go +++ b/pkg/ffapi/apiserver.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -19,6 +19,7 @@ package ffapi import ( "context" "fmt" + "io" "net" "net/http" "time" @@ -86,6 +87,7 @@ type APIServerOptions[T any] struct { type APIServerRouteExt[T any] struct { JSONHandler func(*APIRequest, T) (output interface{}, err error) UploadHandler func(*APIRequest, T) (output interface{}, err error) + StreamHandler func(*APIRequest, T) (output io.ReadCloser, err error) } // NewAPIServer makes a new server, with the specified configuration, and @@ -201,13 +203,25 @@ func (as *apiServer[T]) routeHandler(hf *HandlerFactory, route *Route) http.Hand // We extend the base ffapi functionality, with standardized DB filter support for all core resources. // We also pass the Orchestrator context through ext := route.Extensions.(*APIServerRouteExt[T]) - route.JSONHandler = func(r *APIRequest) (output interface{}, err error) { - er, err := as.EnrichRequest(r) - if err != nil { - return nil, err + switch { + case ext.StreamHandler != nil: + route.StreamHandler = func(r *APIRequest) (output io.ReadCloser, err error) { + er, err := as.EnrichRequest(r) + if err != nil { + return nil, err + } + return ext.StreamHandler(r, er) + } + case ext.JSONHandler != nil: + route.JSONHandler = func(r *APIRequest) (output interface{}, err error) { + er, err := as.EnrichRequest(r) + if err != nil { + return nil, err + } + return ext.JSONHandler(r, er) } - return ext.JSONHandler(r, er) } + return hf.RouteHandler(route) } @@ -247,7 +261,7 @@ func (as *apiServer[T]) createMuxRouter(ctx context.Context) *mux.Router { return ce.UploadHandler(r, er) } } - if ce.JSONHandler != nil || ce.UploadHandler != nil { + if ce.JSONHandler != nil || ce.UploadHandler != nil || ce.StreamHandler != nil { r.HandleFunc(fmt.Sprintf("/api/v1/%s", route.Path), as.routeHandler(hf, route)). Methods(route.Method) } diff --git a/pkg/ffapi/apiserver_test.go b/pkg/ffapi/apiserver_test.go index 8abad5d..f91b5bd 100644 --- a/pkg/ffapi/apiserver_test.go +++ b/pkg/ffapi/apiserver_test.go @@ -19,6 +19,7 @@ package ffapi import ( "context" "fmt" + "github.com/getkin/kin-openapi/openapi3" "io" "net/http" "strings" @@ -38,6 +39,7 @@ type utManager struct { mockEnrichErr error calledJSONHandler string calledUploadHandler string + calledStreamHandler string } type sampleInput struct { @@ -80,6 +82,35 @@ var utAPIRoute1 = &Route{ }, } +var utAPIRoute2 = &Route{ + Name: "utAPIRoute2", + Path: "ut/utresource/{resourceid}/getit", + Method: http.MethodGet, + Description: "random GET stream route for testing", + PathParams: []*PathParam{ + {Name: "resourceid", Description: "My resource"}, + }, + FormParams: nil, + JSONInputValue: nil, + JSONOutputValue: nil, + JSONOutputCodes: nil, + CustomResponseRefs: map[string]*openapi3.ResponseRef{ + "200": { + Value: &openapi3.Response{ + Content: openapi3.Content{ + "application/octet-stream": {}, + }, + }, + }, + }, + Extensions: &APIServerRouteExt[*utManager]{ + StreamHandler: func(r *APIRequest, um *utManager) (output io.ReadCloser, err error) { + um.calledStreamHandler = r.PP["resourceid"] + return io.NopCloser(strings.NewReader("a stream!")), nil + }, + }, +} + func initUTConfig() (config.Section, config.Section, config.Section) { config.RootConfigReset() apiConfig := config.RootSection("ut.api") @@ -97,7 +128,7 @@ func newTestAPIServer(t *testing.T, start bool) (*utManager, *apiServer[*utManag um := &utManager{t: t} as := NewAPIServer(ctx, APIServerOptions[*utManager]{ MetricsRegistry: metric.NewPrometheusMetricsRegistry("ut"), - Routes: []*Route{utAPIRoute1}, + Routes: []*Route{utAPIRoute1, utAPIRoute2}, EnrichRequest: func(r *APIRequest) (*utManager, error) { // This could be some dynamic object based on extra processing in the request, // but the most common case is you just have a "manager" that you inject into each @@ -125,6 +156,24 @@ func newTestAPIServer(t *testing.T, start bool) (*utManager, *apiServer[*utManag } } +func TestAPIServerInvokeAPIRouteStream(t *testing.T) { + um, as, done := newTestAPIServer(t, true) + defer done() + + <-as.Started() + + var o sampleOutput + res, err := resty.New().R(). + SetBody(nil). + SetResult(&o). + Get(fmt.Sprintf("%s/api/v1/ut/utresource/id12345/getit", as.APIPublicURL())) + assert.NoError(t, err) + assert.Equal(t, 200, res.StatusCode()) + assert.Equal(t, "application/octet-stream", res.Header().Get("Content-Type")) + assert.Equal(t, "id12345", um.calledStreamHandler) + assert.Equal(t, "a stream!", string(res.Body())) +} + func TestAPIServerInvokeAPIRouteJSON(t *testing.T) { um, as, done := newTestAPIServer(t, true) defer done() diff --git a/pkg/ffapi/filter.go b/pkg/ffapi/filter.go index d37050e..16ab200 100644 --- a/pkg/ffapi/filter.go +++ b/pkg/ffapi/filter.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -63,6 +63,30 @@ type Filter interface { // Builder returns the builder that made it Builder() FilterBuilder + + // Assert that this is a value filter, not an and/or + ValueFilter() ValueFilter +} + +// ValueFilter is accessor functions for non-Or/And filters - for advanced traversal of the un-finalized tree +type ValueFilter interface { + // The operation for this filter + Op() FilterOp + + // The field name + Field() string + + // The value + Value() interface{} + + // Set the operation for this filter + SetOp(op FilterOp) + + // Set the field name + SetField(f string) + + // Set the value + SetValue(v interface{}) } // MultiConditionFilter gives convenience methods to add conditions @@ -368,6 +392,34 @@ func fieldMods(f Field) []FieldMod { return nil } +func (f *baseFilter) ValueFilter() ValueFilter { + return f +} + +func (f *baseFilter) Op() FilterOp { + return f.op +} + +func (f *baseFilter) Field() string { + return f.field +} + +func (f *baseFilter) Value() interface{} { + return f.value +} + +func (f *baseFilter) SetOp(op FilterOp) { + f.op = op +} + +func (f *baseFilter) SetField(field string) { + f.field = field +} + +func (f *baseFilter) SetValue(value interface{}) { + f.value = value +} + func (f *baseFilter) Finalize() (fi *FilterInfo, err error) { var children []*FilterInfo var value FieldSerialization diff --git a/pkg/ffapi/filter_test.go b/pkg/ffapi/filter_test.go index 42f48e2..cbd52d4 100644 --- a/pkg/ffapi/filter_test.go +++ b/pkg/ffapi/filter_test.go @@ -357,3 +357,19 @@ func TestStringsForTypes(t *testing.T) { assert.Equal(t, "t1,t2", (&ffNameArrayField{na: fftypes.FFStringArray{"t1", "t2"}}).String()) assert.Equal(t, "true", (&boolField{b: true}).String()) } + +func TestValueFilterAccess(t *testing.T) { + fb := TestQueryFactory.NewFilter(context.Background()).Gt("sequence", 0) + assert.NotNil(t, fb.Builder()) + assert.Equal(t, FilterOpGt, fb.ValueFilter().Op()) + assert.Equal(t, "sequence", fb.ValueFilter().Field()) + assert.Equal(t, 0, fb.ValueFilter().Value()) + + fb.ValueFilter().SetOp(FilterOpGte) + fb.ValueFilter().SetField("seq") + fb.ValueFilter().SetValue(12345) + + assert.Equal(t, FilterOpGte, fb.ValueFilter().Op()) + assert.Equal(t, "seq", fb.ValueFilter().Field()) + assert.Equal(t, 12345, fb.ValueFilter().Value()) +} diff --git a/pkg/ffapi/handler.go b/pkg/ffapi/handler.go index 4a7229b..40e5487 100644 --- a/pkg/ffapi/handler.go +++ b/pkg/ffapi/handler.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -189,7 +189,7 @@ func (hs *HandlerFactory) RouteHandler(route *Route) http.HandlerFunc { } } - var status = 400 // if fail parsing input + status := 400 // if fail parsing input var output interface{} if err == nil { queryParams, pathParams, queryArrayParams = hs.getParams(req, route) @@ -202,24 +202,29 @@ func (hs *HandlerFactory) RouteHandler(route *Route) http.HandlerFunc { if err == nil { r := &APIRequest{ - Req: req, - PP: pathParams, - QP: queryParams, - QAP: queryArrayParams, - Filter: filter, - Input: jsonInput, - SuccessStatus: http.StatusOK, + Req: req, + PP: pathParams, + QP: queryParams, + QAP: queryArrayParams, + Filter: filter, + Input: jsonInput, + SuccessStatus: http.StatusOK, + AlwaysPaginate: hs.AlwaysPaginate, + + // res.Header() returns a map which is a ref type so handler header edits are persisted ResponseHeaders: res.Header(), - AlwaysPaginate: hs.AlwaysPaginate, } if len(route.JSONOutputCodes) > 0 { r.SuccessStatus = route.JSONOutputCodes[0] } - if multipart != nil { + switch { + case multipart != nil: r.FP = multipart.formParams r.Part = multipart.part output, err = route.FormUploadHandler(r) - } else { + case route.StreamHandler != nil: + output, err = route.StreamHandler(r) + default: output, err = route.JSONHandler(r) } status = r.SuccessStatus // Can be updated by the route @@ -259,7 +264,9 @@ func (hs *HandlerFactory) handleOutput(ctx context.Context, res http.ResponseWri res.WriteHeader(204) case reader != nil: defer reader.Close() - res.Header().Add("Content-Type", "application/octet-stream") + if res.Header().Get("Content-Type") == "" { + res.Header().Add("Content-Type", "application/octet-stream") + } res.WriteHeader(status) _, marshalErr = io.Copy(res, reader) default: diff --git a/pkg/ffapi/handler_test.go b/pkg/ffapi/handler_test.go index 3f8375b..b8ffd21 100644 --- a/pkg/ffapi/handler_test.go +++ b/pkg/ffapi/handler_test.go @@ -21,6 +21,8 @@ import ( "context" "encoding/json" "fmt" + "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/require" "io" "mime/multipart" "net/http" @@ -156,6 +158,72 @@ func TestJSONHTTPNilResponseNon204(t *testing.T) { assert.Regexp(t, "FF00164", resJSON["error"]) } +func TestStreamHttpResponsePlainText200(t *testing.T) { + text := ` +some stream +of +text +!!! +` + s, _, done := newTestServer(t, []*Route{{ + Name: "testRoute", + Path: "/test", + Method: "GET", + CustomResponseRefs: map[string]*openapi3.ResponseRef{ + "200": { + Value: &openapi3.Response{ + Content: openapi3.Content{ + "text/plain": {}, + }, + }, + }, + }, + StreamHandler: func(r *APIRequest) (output io.ReadCloser, err error) { + r.ResponseHeaders.Add("Content-Type", "text/plain") + return io.NopCloser(strings.NewReader(text)), nil + }, + }}, "", nil) + defer done() + + res, err := http.Get(fmt.Sprintf("http://%s/test", s.Addr())) + require.NoError(t, err) + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, "text/plain", res.Header.Get("Content-Type")) + b, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.Equal(t, text, string(b)) +} + +func TestStreamHttpResponseBinary200(t *testing.T) { + randomBytes := []byte{3, 255, 192, 201, 33, 50} + s, _, done := newTestServer(t, []*Route{{ + Name: "testRoute", + Path: "/test", + Method: "GET", + CustomResponseRefs: map[string]*openapi3.ResponseRef{ + "200": { + Value: &openapi3.Response{ + Content: openapi3.Content{ + "application/octet-stream": &openapi3.MediaType{}, + }, + }, + }, + }, + StreamHandler: func(r *APIRequest) (output io.ReadCloser, err error) { + return io.NopCloser(bytes.NewReader(randomBytes)), nil + }, + }}, "", nil) + defer done() + + res, err := http.Get(fmt.Sprintf("http://%s/test", s.Addr())) + require.NoError(t, err) + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, "application/octet-stream", res.Header.Get("Content-Type")) + b, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.Equal(t, randomBytes, b) +} + func TestJSONHTTPDefault500Error(t *testing.T) { s, _, done := newTestServer(t, []*Route{{ Name: "testRoute", diff --git a/pkg/ffapi/openapi3.go b/pkg/ffapi/openapi3.go index c67868b..2bcee7f 100644 --- a/pkg/ffapi/openapi3.go +++ b/pkg/ffapi/openapi3.go @@ -313,6 +313,12 @@ func (sg *SwaggerGen) addOutput(ctx context.Context, doc *openapi3.T, route *Rou }, }) } + for code, res := range route.CustomResponseRefs { + if res.Value != nil && res.Value.Description == nil { + res.Value.Description = &s + } + op.Responses.Set(code, res) + } } func (sg *SwaggerGen) AddParam(ctx context.Context, op *openapi3.Operation, in, name, def, example string, description i18n.MessageKey, deprecated bool, msgArgs ...interface{}) { diff --git a/pkg/ffapi/openapi3_test.go b/pkg/ffapi/openapi3_test.go index ee1eb98..f44911a 100644 --- a/pkg/ffapi/openapi3_test.go +++ b/pkg/ffapi/openapi3_test.go @@ -19,6 +19,7 @@ package ffapi import ( "context" "fmt" + "github.com/stretchr/testify/require" "net/http" "testing" @@ -298,6 +299,36 @@ func TestFFExcludeTag(t *testing.T) { assert.Regexp(t, "no schema", err) } +func TestCustomResponseRefs(t *testing.T) { + routes := []*Route{ + { + Name: "CustomResponseRefTest", + Path: "/test", + Method: http.MethodGet, + CustomResponseRefs: map[string]*openapi3.ResponseRef{ + "200": { + Value: &openapi3.Response{ + Content: openapi3.Content{ + "text/plain": &openapi3.MediaType{}, + }, + }, + }, + }, + }, + } + swagger := NewSwaggerGen(&SwaggerGenOptions{ + Title: "UnitTest", + Version: "1.0", + BaseURL: "http://localhost:12345/api/v1", + }).Generate(context.Background(), routes) + assert.Nil(t, swagger.Paths.Find("/test").Get.RequestBody) + require.NotEmpty(t, swagger.Paths.Find("/test").Get.Responses) + require.NotNil(t, swagger.Paths.Find("/test").Get.Responses.Value("200")) + require.NotNil(t, swagger.Paths.Find("/test").Get.Responses.Value("200").Value) + assert.NotNil(t, swagger.Paths.Find("/test").Get.Responses.Value("200").Value.Content.Get("text/plain")) + assert.Nil(t, swagger.Paths.Find("/test").Get.Responses.Value("201")) +} + func TestPanicOnMissingDescription(t *testing.T) { routes := []*Route{ { diff --git a/pkg/ffapi/restfilter_json.go b/pkg/ffapi/restfilter_json.go index 64285f4..b465b5c 100644 --- a/pkg/ffapi/restfilter_json.go +++ b/pkg/ffapi/restfilter_json.go @@ -67,7 +67,11 @@ type FilterJSONKeyValues struct { } type FilterJSON struct { - Or []*FilterJSON `ffstruct:"FilterJSON" json:"or,omitempty"` + Or []*FilterJSON `ffstruct:"FilterJSON" json:"or,omitempty"` + FilterJSONOps +} + +type FilterJSONOps struct { Equal []*FilterJSONKeyValue `ffstruct:"FilterJSON" json:"equal,omitempty"` Eq []*FilterJSONKeyValue `ffstruct:"FilterJSON" json:"eq,omitempty"` // short name NEq []*FilterJSONKeyValue `ffstruct:"FilterJSON" json:"neq,omitempty"` // negated short name @@ -83,6 +87,7 @@ type FilterJSON struct { GTE []*FilterJSONKeyValue `ffstruct:"FilterJSON" json:"gte,omitempty"` // short name In []*FilterJSONKeyValues `ffstruct:"FilterJSON" json:"in,omitempty"` NIn []*FilterJSONKeyValues `ffstruct:"FilterJSON" json:"nin,omitempty"` // negated short name + Null []*FilterJSONBase `ffstruct:"FilterJSON" json:"null,omitempty"` } type QueryJSON struct { @@ -95,6 +100,35 @@ type QueryJSON struct { type SimpleFilterValue string +type resolveCtx struct { + ctx context.Context + jsonFilter *FilterJSON + valueResolver ValueResolverFn + skipFieldValidation bool + err error +} + +type ValueResolverFn func(ctx context.Context, level *FilterJSON, fieldName, suppliedValue string) (driver.Value, error) + +type JSONBuildFilterOpt struct { + valueResolver ValueResolverFn + skipFieldValidation bool +} + +// Option to add a handler that will be called at each OR level, before performing the normal +// processing on each +func ValueResolver(fn ValueResolverFn) *JSONBuildFilterOpt { + return &JSONBuildFilterOpt{ + valueResolver: fn, + } +} + +func SkipFieldValidation() *JSONBuildFilterOpt { + return &JSONBuildFilterOpt{ + skipFieldValidation: true, + } +} + func (js *SimpleFilterValue) UnmarshalJSON(b []byte) error { var v interface{} err := json.Unmarshal(b, &v) @@ -120,7 +154,7 @@ func (js SimpleFilterValue) String() string { return (string)(js) } -func (jq *QueryJSON) BuildFilter(ctx context.Context, qf QueryFactory) (Filter, error) { +func (jq *QueryJSON) BuildFilter(ctx context.Context, qf QueryFactory, options ...*JSONBuildFilterOpt) (Filter, error) { fb := qf.NewFilter(ctx) if jq.Count != nil { fb = fb.Count(*jq.Count) @@ -134,10 +168,13 @@ func (jq *QueryJSON) BuildFilter(ctx context.Context, qf QueryFactory) (Filter, for _, s := range jq.Sort { fb = fb.Sort(s) } - return jq.BuildSubFilter(ctx, fb, &jq.FilterJSON) + return (&jq.FilterJSON).BuildSubFilter(ctx, fb, options...) } -func validateFilterField(ctx context.Context, fb FilterBuilder, fieldAnyCase string) (string, error) { +func validateFilterField(ctx context.Context, fb FilterBuilder, fieldAnyCase string, rv *resolveCtx) (string, error) { + if rv.skipFieldValidation { + return fieldAnyCase, nil + } for _, f := range fb.Fields() { if strings.EqualFold(fieldAnyCase, f) { return f, nil @@ -146,64 +183,75 @@ func validateFilterField(ctx context.Context, fb FilterBuilder, fieldAnyCase str return "", i18n.NewError(ctx, i18n.MsgInvalidFilterField, fieldAnyCase) } -func (jq *QueryJSON) addSimpleFilters(ctx context.Context, fb FilterBuilder, jsonFilter *FilterJSON, andFilter AndFilter) (AndFilter, error) { - for _, e := range joinShortNames(jsonFilter.Equal, jsonFilter.Eq, jsonFilter.NEq) { - field, err := validateFilterField(ctx, fb, e.Field) +func (jf *FilterJSON) addSimpleFilters(ctx context.Context, fb FilterBuilder, andFilter AndFilter, rv *resolveCtx) (AndFilter, error) { + for _, e := range joinShortNames(jf.Equal, jf.Eq, jf.NEq) { + field, err := validateFilterField(ctx, fb, e.Field, rv) if err != nil { return nil, err } if e.CaseInsensitive { if e.Not { - andFilter = andFilter.Condition(fb.NIeq(field, e.Value.String())) + andFilter = andFilter.Condition(fb.NIeq(field, rv.resolve(field, e.Value.String()))) } else { - andFilter = andFilter.Condition(fb.IEq(field, e.Value.String())) + andFilter = andFilter.Condition(fb.IEq(field, rv.resolve(field, e.Value.String()))) } } else { if e.Not { - andFilter = andFilter.Condition(fb.Neq(field, e.Value.String())) + andFilter = andFilter.Condition(fb.Neq(field, rv.resolve(field, e.Value.String()))) } else { - andFilter = andFilter.Condition(fb.Eq(field, e.Value.String())) + andFilter = andFilter.Condition(fb.Eq(field, rv.resolve(field, e.Value.String()))) } } } - for _, e := range jsonFilter.Contains { - field, err := validateFilterField(ctx, fb, e.Field) + for _, e := range jf.Contains { + field, err := validateFilterField(ctx, fb, e.Field, rv) if err != nil { return nil, err } if e.CaseInsensitive { if e.Not { - andFilter = andFilter.Condition(fb.NotIContains(field, e.Value.String())) + andFilter = andFilter.Condition(fb.NotIContains(field, rv.resolve(field, e.Value.String()))) } else { - andFilter = andFilter.Condition(fb.IContains(field, e.Value.String())) + andFilter = andFilter.Condition(fb.IContains(field, rv.resolve(field, e.Value.String()))) } } else { if e.Not { - andFilter = andFilter.Condition(fb.NotContains(field, e.Value.String())) + andFilter = andFilter.Condition(fb.NotContains(field, rv.resolve(field, e.Value.String()))) } else { - andFilter = andFilter.Condition(fb.Contains(field, e.Value.String())) + andFilter = andFilter.Condition(fb.Contains(field, rv.resolve(field, e.Value.String()))) } } } - for _, e := range jsonFilter.StartsWith { - field, err := validateFilterField(ctx, fb, e.Field) + for _, e := range jf.StartsWith { + field, err := validateFilterField(ctx, fb, e.Field, rv) if err != nil { return nil, err } if e.CaseInsensitive { if e.Not { - andFilter = andFilter.Condition(fb.NotIStartsWith(field, e.Value.String())) + andFilter = andFilter.Condition(fb.NotIStartsWith(field, rv.resolve(field, e.Value.String()))) } else { - andFilter = andFilter.Condition(fb.IStartsWith(field, e.Value.String())) + andFilter = andFilter.Condition(fb.IStartsWith(field, rv.resolve(field, e.Value.String()))) } } else { if e.Not { - andFilter = andFilter.Condition(fb.NotStartsWith(field, e.Value.String())) + andFilter = andFilter.Condition(fb.NotStartsWith(field, rv.resolve(field, e.Value.String()))) } else { - andFilter = andFilter.Condition(fb.StartsWith(field, e.Value.String())) + andFilter = andFilter.Condition(fb.StartsWith(field, rv.resolve(field, e.Value.String()))) } } } + for _, e := range jf.Null { + field, err := validateFilterField(ctx, fb, e.Field, rv) + if err != nil { + return nil, err + } + if e.Not { + andFilter = andFilter.Condition(fb.Neq(field, nil)) + } else { + andFilter = andFilter.Condition(fb.Eq(field, nil)) + } + } return andFilter, nil } @@ -230,53 +278,98 @@ func joinInAndNin(in, nin []*FilterJSONKeyValues) []*FilterJSONKeyValues { return res } -func (jq *QueryJSON) BuildSubFilter(ctx context.Context, fb FilterBuilder, jsonFilter *FilterJSON) (Filter, error) { - andFilter, err := jq.addSimpleFilters(ctx, fb, jsonFilter, fb.And()) +func (rv *resolveCtx) resolve(fieldName string, v string) driver.Value { + if rv.valueResolver == nil { + return v + } + resolved, err := rv.valueResolver(rv.ctx, rv.jsonFilter, fieldName, v) + if err != nil { + rv.err = err + return "" + } + return resolved +} + +func (rv *resolveCtx) resolveMany(fieldName string, values []SimpleFilterValue) []driver.Value { + driverValues := make([]driver.Value, len(values)) + for i, v := range values { + driverValues[i] = rv.resolve(fieldName, v.String()) + } + return driverValues +} + +func buildResolveCtx(ctx context.Context, jsonFilter *FilterJSON, options ...*JSONBuildFilterOpt) *resolveCtx { + rv := &resolveCtx{ctx: ctx, jsonFilter: jsonFilter} + for _, o := range options { + if o.valueResolver != nil { + rv.valueResolver = o.valueResolver + } + if o.skipFieldValidation { + rv.skipFieldValidation = true + } + } + return rv +} + +func (jf *FilterJSON) BuildSubFilter(ctx context.Context, fb FilterBuilder, options ...*JSONBuildFilterOpt) (Filter, error) { + andFilter, err := jf.BuildAndFilter(ctx, fb, options...) if err != nil { return nil, err } - for _, e := range joinShortNames(jsonFilter.LessThan, jsonFilter.LT, nil) { - field, err := validateFilterField(ctx, fb, e.Field) + if len(andFilter.GetConditions()) == 1 { + return andFilter.GetConditions()[0], nil + } + return andFilter, nil +} + +func (jf *FilterJSON) BuildAndFilter(ctx context.Context, fb FilterBuilder, options ...*JSONBuildFilterOpt) (AndFilter, error) { + rv := buildResolveCtx(ctx, jf, options...) + andFilter, err := jf.addSimpleFilters(ctx, fb, fb.And(), rv) + if err != nil { + return nil, err + } + for _, e := range joinShortNames(jf.LessThan, jf.LT, nil) { + field, err := validateFilterField(ctx, fb, e.Field, rv) if err != nil { return nil, err } if e.CaseInsensitive || e.Not { return nil, i18n.NewError(ctx, i18n.MsgJSONQueryOpUnsupportedMod, "lessThan", allMods) } - andFilter = andFilter.Condition(fb.Lt(field, e.Value.String())) + andFilter = andFilter.Condition(fb.Lt(field, rv.resolve(field, e.Value.String()))) } - for _, e := range joinShortNames(jsonFilter.LessThanOrEqual, jsonFilter.LTE, nil) { - field, err := validateFilterField(ctx, fb, e.Field) + for _, e := range joinShortNames(jf.LessThanOrEqual, jf.LTE, nil) { + field, err := validateFilterField(ctx, fb, e.Field, rv) if err != nil { return nil, err } if e.CaseInsensitive || e.Not { return nil, i18n.NewError(ctx, i18n.MsgJSONQueryOpUnsupportedMod, "lessThanOrEqual", allMods) } - andFilter = andFilter.Condition(fb.Lte(field, e.Value.String())) + andFilter = andFilter.Condition(fb.Lte(field, rv.resolve(field, e.Value.String()))) } - for _, e := range joinShortNames(jsonFilter.GreaterThan, jsonFilter.GT, nil) { - field, err := validateFilterField(ctx, fb, e.Field) + for _, e := range joinShortNames(jf.GreaterThan, jf.GT, nil) { + field, err := validateFilterField(ctx, fb, e.Field, rv) if err != nil { return nil, err } if e.CaseInsensitive || e.Not { return nil, i18n.NewError(ctx, i18n.MsgJSONQueryOpUnsupportedMod, "greaterThan", allMods) } - andFilter = andFilter.Condition(fb.Gt(field, e.Value.String())) + andFilter = andFilter.Condition(fb.Gt(field, rv.resolve(field, e.Value.String()))) } - for _, e := range joinShortNames(jsonFilter.GreaterThanOrEqual, jsonFilter.GTE, nil) { - field, err := validateFilterField(ctx, fb, e.Field) + for _, e := range joinShortNames(jf.GreaterThanOrEqual, jf.GTE, nil) { + field, err := validateFilterField(ctx, fb, e.Field, rv) if err != nil { return nil, err } if e.CaseInsensitive || e.Not { return nil, i18n.NewError(ctx, i18n.MsgJSONQueryOpUnsupportedMod, "greaterThanOrEqual", allMods) } - andFilter = andFilter.Condition(fb.Gte(field, e.Value.String())) + andFilter = andFilter.Condition(fb.Gte(field, rv.resolve(field, e.Value.String()))) } - for _, e := range joinInAndNin(jsonFilter.In, jsonFilter.NIn) { - field, err := validateFilterField(ctx, fb, e.Field) + for _, e := range joinInAndNin(jf.In, jf.NIn) { + field, err := validateFilterField(ctx, fb, e.Field, rv) if err != nil { return nil, err } @@ -284,15 +377,15 @@ func (jq *QueryJSON) BuildSubFilter(ctx context.Context, fb FilterBuilder, jsonF return nil, i18n.NewError(ctx, i18n.MsgJSONQueryOpUnsupportedMod, "in", justCaseInsensitive) } if e.Not { - andFilter = andFilter.Condition(fb.NotIn(field, toDriverValues(e.Values))) + andFilter = andFilter.Condition(fb.NotIn(field, rv.resolveMany(field, e.Values))) } else { - andFilter = andFilter.Condition(fb.In(field, toDriverValues(e.Values))) + andFilter = andFilter.Condition(fb.In(field, rv.resolveMany(field, e.Values))) } } - if len(jsonFilter.Or) > 0 { + if len(jf.Or) > 0 { childFilter := fb.Or() - for _, child := range jsonFilter.Or { - subFilter, err := jq.BuildSubFilter(ctx, fb, child) + for _, child := range jf.Or { + subFilter, err := child.BuildSubFilter(ctx, fb, options...) if err != nil { return nil, err } @@ -304,16 +397,9 @@ func (jq *QueryJSON) BuildSubFilter(ctx context.Context, fb FilterBuilder, jsonF andFilter.Condition(childFilter) } } - if len(andFilter.GetConditions()) == 1 { - return andFilter.GetConditions()[0], nil + // Any error that occurred as part of the resolver plugin, need to be reconciled + if rv.err != nil { + return nil, rv.err } return andFilter, nil } - -func toDriverValues(values []SimpleFilterValue) []driver.Value { - driverValues := make([]driver.Value, len(values)) - for i, v := range values { - driverValues[i] = v.String() - } - return driverValues -} diff --git a/pkg/ffapi/restfilter_json_test.go b/pkg/ffapi/restfilter_json_test.go index 8ce9f49..d237c43 100644 --- a/pkg/ffapi/restfilter_json_test.go +++ b/pkg/ffapi/restfilter_json_test.go @@ -18,7 +18,9 @@ package ffapi import ( "context" + "database/sql/driver" "encoding/json" + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -52,6 +54,11 @@ func TestBuildQueryJSONNestedAndOr(t *testing.T) { "value": 999 } ], + "null": [ + { + "field": "cid" + } + ], "greaterThan": [ { "field": "sequence", @@ -101,7 +108,7 @@ func TestBuildQueryJSONNestedAndOr(t *testing.T) { fi, err := filter.Finalize() assert.NoError(t, err) - assert.Equal(t, "( tag == 'a' ) && ( masked == true ) && ( sequence != 999 ) && ( sequence >> 10 ) && ( ( ( masked == true ) && ( tag IN ['a','b','c'] ) && ( tag NI ['x','y'] ) && ( tag NI ['z'] ) ) || ( masked == false ) ) sort=tag,-sequence skip=5 limit=10", fi.String()) + assert.Equal(t, "( tag == 'a' ) && ( masked == true ) && ( sequence != 999 ) && ( cid == null ) && ( sequence >> 10 ) && ( ( ( masked == true ) && ( tag IN ['a','b','c'] ) && ( tag NI ['x','y'] ) && ( tag NI ['z'] ) ) || ( masked == false ) ) sort=tag,-sequence skip=5 limit=10", fi.String()) } func TestBuildQuerySingleNestedOr(t *testing.T) { @@ -130,6 +137,89 @@ func TestBuildQuerySingleNestedOr(t *testing.T) { assert.Equal(t, "tag == 'a'", fi.String()) } +func TestBuildQuerySkipFieldValidation(t *testing.T) { + + var jf *FilterJSON + err := json.Unmarshal([]byte(`{ + "equal": [ + { + "field": "anything at all", + "value": "a" + } + ] + }`), &jf) + assert.NoError(t, err) + + fb := TestQueryFactory.NewFilter(context.Background()) + andFilter, err := jf.BuildAndFilter(context.Background(), fb, SkipFieldValidation()) + assert.NoError(t, err) + conditions := andFilter.GetConditions() + assert.Len(t, conditions, 1) + + cond0 := conditions[0].ValueFilter() + assert.Equal(t, FilterOpEq, cond0.Op()) + assert.Equal(t, "anything at all", cond0.Field()) + assert.Equal(t, "a", cond0.Value()) + +} + +func TestBuildQuerySingleNestedWithResolverOk(t *testing.T) { + + var qf QueryJSON + err := json.Unmarshal([]byte(`{ + "or": [ + { + "equal": [ + { + "field": "tag", + "value": "a" + } + ] + } + ] + }`), &qf) + assert.NoError(t, err) + + filter, err := qf.BuildFilter(context.Background(), TestQueryFactory, ValueResolver(func(ctx context.Context, level *FilterJSON, fieldName, suppliedValue string) (driver.Value, error) { + assert.Equal(t, "tag", fieldName) + assert.Equal(t, "a", suppliedValue) + assert.Len(t, level.Equal, 1) + return "b", nil + })) + assert.NoError(t, err) + + fi, err := filter.Finalize() + assert.NoError(t, err) + + assert.Equal(t, "tag == 'b'", fi.String()) +} + +func TestBuildQuerySingleNestedWithResolverError(t *testing.T) { + + var qf QueryJSON + err := json.Unmarshal([]byte(`{ + "or": [ + { + "in": [ + { + "field": "tag", + "values": ["a"] + } + ] + } + ] + }`), &qf) + assert.NoError(t, err) + + _, err = qf.BuildFilter(context.Background(), TestQueryFactory, ValueResolver(func(ctx context.Context, level *FilterJSON, fieldName, suppliedValue string) (driver.Value, error) { + assert.Equal(t, "tag", fieldName) + assert.Equal(t, "a", suppliedValue) + assert.Len(t, level.In, 1) + return "", fmt.Errorf("pop") + })) + assert.Regexp(t, "pop", err) +} + func TestBuildQueryJSONEqual(t *testing.T) { var qf QueryJSON @@ -162,6 +252,12 @@ func TestBuildQueryJSONEqual(t *testing.T) { "field": "tag", "value": "abc" } + ], + "null": [ + { + "not": true, + "field": "cid" + } ] }`), &qf) assert.NoError(t, err) @@ -172,7 +268,7 @@ func TestBuildQueryJSONEqual(t *testing.T) { fi, err := filter.Finalize() assert.NoError(t, err) - assert.Equal(t, "( created == 0 ) && ( tag != 'abc' ) && ( tag := 'ABC' ) && ( tag ;= 'abc' ) sort=tag,sequence skip=5 limit=10 count=true", fi.String()) + assert.Equal(t, "( created == 0 ) && ( tag != 'abc' ) && ( tag := 'ABC' ) && ( tag ;= 'abc' ) && ( cid != null ) sort=tag,sequence skip=5 limit=10 count=true", fi.String()) } func TestBuildQueryJSONContains(t *testing.T) { @@ -480,6 +576,12 @@ func TestBuildQueryJSONBadFields(t *testing.T) { assert.NoError(t, err) _, err = qf8.BuildFilter(context.Background(), TestQueryFactory) assert.Regexp(t, "FF00142", err) + + var qf9 QueryJSON + err = json.Unmarshal([]byte(`{"null": [{"field": "wrong"}]}`), &qf9) + assert.NoError(t, err) + _, err = qf9.BuildFilter(context.Background(), TestQueryFactory) + assert.Regexp(t, "FF00142", err) } func TestBuildQueryJSONDocumented(t *testing.T) { diff --git a/pkg/ffapi/routes.go b/pkg/ffapi/routes.go index 98b9203..12578ef 100644 --- a/pkg/ffapi/routes.go +++ b/pkg/ffapi/routes.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,6 +18,7 @@ package ffapi import ( "context" + "io" "github.com/getkin/kin-openapi/openapi3" "github.com/hyperledger/firefly-common/pkg/config" @@ -61,12 +62,16 @@ type Route struct { JSONOutputSchema func(ctx context.Context, schemaGen SchemaGenerator) (*openapi3.SchemaRef, error) // JSONOutputValue is a function that returns a pointer to a structure to take JSON output JSONOutputValue func() interface{} - // JSONOutputCodes is the success response code + // JSONOutputCodes is the success response codes that could be returned by the API. Error codes are explicitly not supported by the framework since they could be subject to change by the errors thrown or how errors are handled. JSONOutputCodes []int - // JSONHandler is a function for handling JSON content type input. Input/Ouptut objects are returned by JSONInputValue/JSONOutputValue funcs + // JSONHandler is a function for handling JSON content type input. Input/Output objects are returned by JSONInputValue/JSONOutputValue funcs JSONHandler func(r *APIRequest) (output interface{}, err error) // FormUploadHandler takes a single file upload, and returns a JSON object FormUploadHandler func(r *APIRequest) (output interface{}, err error) + // StreamHandler allows for custom request handling with explicit stream (io.ReadCloser) responses + StreamHandler func(r *APIRequest) (output io.ReadCloser, err error) + // CustomResponseRefs allows for specifying custom responses for a route + CustomResponseRefs map[string]*openapi3.ResponseRef // Deprecated whether this route is deprecated Deprecated bool // Tag a category identifier for this route in the generated OpenAPI spec diff --git a/pkg/fftypes/jsonobject.go b/pkg/fftypes/jsonobject.go index 682da86..4598dbf 100644 --- a/pkg/fftypes/jsonobject.go +++ b/pkg/fftypes/jsonobject.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -133,8 +133,8 @@ func (jd JSONObject) GetObject(key string) JSONObject { } func (jd JSONObject) GetObjectOk(key string) (JSONObject, bool) { - vInterace, ok := jd[key] - if ok && vInterace != nil { + vInterface, ok := jd[key] + if ok && vInterface != nil { vInterface := jd[key] switch vMap := vInterface.(type) { case map[string]interface{}: @@ -142,7 +142,7 @@ func (jd JSONObject) GetObjectOk(key string) (JSONObject, bool) { case JSONObject: return vMap, true default: - log.L(context.Background()).Errorf("Invalid object value '%+v' for key '%s'", vInterace, key) + log.L(context.Background()).Errorf("Invalid object value '%+v' for key '%s'", vInterface, key) return JSONObject{}, false // Ensures a non-nil return } } @@ -187,11 +187,10 @@ func (jd JSONObject) GetObjectArray(key string) JSONObjectArray { } func (jd JSONObject) GetObjectArrayOk(key string) (JSONObjectArray, bool) { - vInterace, ok := jd[key] - if ok && vInterace != nil { - return ToJSONObjectArray(vInterace) + vInterface, ok := jd[key] + if ok && vInterface != nil { + return ToJSONObjectArray(vInterface) } - log.L(context.Background()).Errorf("Invalid object value '%+v' for key '%s'", vInterace, key) return JSONObjectArray{}, false // Ensures a non-nil return } @@ -201,11 +200,11 @@ func (jd JSONObject) GetStringArray(key string) []string { } func (jd JSONObject) GetStringArrayOk(key string) ([]string, bool) { - vInterace, ok := jd[key] - if ok && vInterace != nil { - return ToStringArray(vInterace) + vInterface, ok := jd[key] + if ok && vInterface != nil { + return ToStringArray(vInterface) } - log.L(context.Background()).Errorf("Invalid string array value '%+v' for key '%s'", vInterace, key) + log.L(context.Background()).Errorf("Invalid string array value '%+v' for key '%s'", vInterface, key) return []string{}, false // Ensures a non-nil return } diff --git a/pkg/i18n/en_base_error_messages.go b/pkg/i18n/en_base_error_messages.go index de96e7e..2749cfc 100644 --- a/pkg/i18n/en_base_error_messages.go +++ b/pkg/i18n/en_base_error_messages.go @@ -178,5 +178,7 @@ var ( MsgJSONQueryOpUnsupportedMod = ffe("FF00240", "Operation '%s' does not support modifiers: %v", 400) MsgJSONQueryValueUnsupported = ffe("FF00241", "Field value not supported (must be string, number, or boolean): %s", 400) MsgJSONQuerySortUnsupported = ffe("FF00242", "Invalid 'order' for sort (must be 'asc', 'ascending', 'desc' or 'descending'): %s", 400) - MsgDBExecFailed = ffe("FF00243", "Database update failed") + MsgWebSocketBatchInflight = ffe("FF00243", "Stream '%s' already has batch '%d' inflight on websocket connection '%s'") + MsgWebSocketRoundTripTimeout = ffe("FF00244", "Timed out or cancelled waiting for acknowledgement") + MsgDBExecFailed = ffe("FF00245", "Database update failed") ) diff --git a/pkg/i18n/en_base_field_descriptions.go b/pkg/i18n/en_base_field_descriptions.go index 19f1455..7dde0ac 100644 --- a/pkg/i18n/en_base_field_descriptions.go +++ b/pkg/i18n/en_base_field_descriptions.go @@ -37,6 +37,7 @@ var ( FilterJSONLTE = ffm("FilterJSON.lte", "Short name for lessThanOrEqual") FilterJSONIn = ffm("FilterJSON.in", "Array of field + values-array combinations to apply as 'in' filters (matching one of a set of values) - all filters must match") FilterJSONNIn = ffm("FilterJSON.nin", "Shortcut for in with all conditions negated (the not property of all children is overridden)") + FilterJSONNull = ffm("FilterJSON.null", "Tests if the specified field is null (unset)") FilterJSONLimit = ffm("FilterJSON.limit", "Limit on the results to return") FilterJSONSkip = ffm("FilterJSON.skip", "Number of results to skip before returning entries, for skip+limit based pagination") FilterJSONSort = ffm("FilterJSON.sort", "Array of fields to sort by. A '-' prefix on a field requests that field is sorted in descending order") diff --git a/pkg/wsserver/config.go b/pkg/wsserver/config.go new file mode 100644 index 0000000..5ecf1f0 --- /dev/null +++ b/pkg/wsserver/config.go @@ -0,0 +1,49 @@ +// Copyright © 2024 Kaleido, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wsserver + +import ( + "time" + + "github.com/hyperledger/firefly-common/pkg/config" +) + +type WebSocketServerConfig struct { + AckTimeout time.Duration + ReadBufferSize int64 + WriteBufferSize int64 +} + +const ( + ConfigAckTimeout = "ackTimeout" + ConfigReadBufferSize = "readBufferSize" + ConfigWriteBufferSize = "writeBufferSize" +) + +func InitConfig(conf config.Section) { + conf.AddKnownKey(ConfigAckTimeout, "2m") + conf.AddKnownKey(ConfigReadBufferSize, "4KB") + conf.AddKnownKey(ConfigWriteBufferSize, "4KB") +} + +func GenerateConfig(conf config.Section) *WebSocketServerConfig { + return &WebSocketServerConfig{ + AckTimeout: conf.GetDuration(ConfigAckTimeout), + ReadBufferSize: conf.GetByteSize(ConfigReadBufferSize), + WriteBufferSize: conf.GetByteSize(ConfigWriteBufferSize), + } +} diff --git a/pkg/wsserver/config_test.go b/pkg/wsserver/config_test.go new file mode 100644 index 0000000..700a590 --- /dev/null +++ b/pkg/wsserver/config_test.go @@ -0,0 +1,40 @@ +// Copyright © 2023 Kaleido, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wsserver + +import ( + "testing" + "time" + + "github.com/hyperledger/firefly-common/pkg/config" + "gotest.tools/assert" +) + +func TestGenerateConfigTLS(t *testing.T) { + + config.RootConfigReset() + conf := config.RootSection("ut") + InitConfig(conf) + + configObj := GenerateConfig(conf) + assert.Equal(t, WebSocketServerConfig{ + AckTimeout: 2 * time.Minute, + ReadBufferSize: 4096, + WriteBufferSize: 4096, + }, *configObj) + +} diff --git a/pkg/wsserver/wsconn.go b/pkg/wsserver/wsconn.go index 89f986c..bb51d87 100644 --- a/pkg/wsserver/wsconn.go +++ b/pkg/wsserver/wsconn.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,7 +18,6 @@ package wsserver import ( "context" - "reflect" "strings" "sync" @@ -29,21 +28,14 @@ import ( ) type webSocketConnection struct { - ctx context.Context - id string - server *webSocketServer - conn *ws.Conn - mux sync.Mutex - closed bool - streams map[string]*webSocketStream - broadcast chan interface{} - newStream chan bool - closing chan struct{} -} - -type WebSocketCommandMessageOrError struct { - Msg *WebSocketCommandMessage - Err error + ctx context.Context + id string + server *webSocketServer + conn *ws.Conn + closeMux sync.Mutex + closed bool + send chan interface{} + closing chan struct{} } type WebSocketCommandMessage struct { @@ -56,14 +48,12 @@ type WebSocketCommandMessage struct { func newConnection(bgCtx context.Context, server *webSocketServer, conn *ws.Conn) *webSocketConnection { id := fftypes.NewUUID().String() wsc := &webSocketConnection{ - ctx: log.WithLogField(bgCtx, "wsc", id), - id: id, - server: server, - conn: conn, - newStream: make(chan bool), - streams: make(map[string]*webSocketStream), - broadcast: make(chan interface{}), - closing: make(chan struct{}), + ctx: log.WithLogField(bgCtx, "wsc", id), + id: id, + server: server, + conn: conn, + send: make(chan interface{}), + closing: make(chan struct{}), } go wsc.listen() go wsc.sender() @@ -71,66 +61,31 @@ func newConnection(bgCtx context.Context, server *webSocketServer, conn *ws.Conn } func (c *webSocketConnection) close() { - c.mux.Lock() + c.closeMux.Lock() if !c.closed { c.closed = true c.conn.Close() close(c.closing) } - c.mux.Unlock() + c.closeMux.Unlock() - for _, t := range c.streams { - c.server.cycleStream(c.id, t) - log.L(c.ctx).Infof("Websocket closed while active on stream '%s'", t.streamName) - } c.server.connectionClosed(c) log.L(c.ctx).Infof("Disconnected") } func (c *webSocketConnection) sender() { defer c.close() - buildCases := func() []reflect.SelectCase { - c.mux.Lock() - defer c.mux.Unlock() - cases := make([]reflect.SelectCase, len(c.streams)+3) - i := 0 - for _, t := range c.streams { - cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(t.senderChannel)} - i++ - } - cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(c.broadcast)} - i++ - cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(c.closing)} - i++ - cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(c.newStream)} - return cases - } - cases := buildCases() for { - chosen, value, ok := reflect.Select(cases) - if !ok { + select { + case payload := <-c.send: + if err := c.conn.WriteJSON(payload); err != nil { + log.L(c.ctx).Errorf("Send failed - closing connection: %s", err) + return + } + case <-c.closing: log.L(c.ctx).Infof("Closing") return } - - if chosen == len(cases)-1 { - // Addition of a new stream - cases = buildCases() - } else { - // Message from one of the existing streams - _ = c.conn.WriteJSON(value.Interface()) - } - } -} - -func (c *webSocketConnection) startStream(t *webSocketStream) { - c.mux.Lock() - c.streams[t.streamName] = t - c.server.StreamStarted(c, t.streamName) - c.mux.Unlock() - select { - case c.newStream <- true: - case <-c.closing: } } @@ -145,39 +100,15 @@ func (c *webSocketConnection) listen() { return } log.L(c.ctx).Tracef("Received: %+v", msg) - - t := c.server.getStream(msg.Stream) switch strings.ToLower(msg.Type) { case "start": - c.startStream(t) + c.server.streamStarted(c, msg.Stream) case "ack": - if !c.dispatchAckOrError(t, &msg, nil) { - return - } + c.server.completeRoundTrip(msg.Stream, &msg, nil) case "error", "nack": - if !c.dispatchAckOrError(t, &msg, i18n.NewError(c.ctx, i18n.MsgWSErrorFromClient, msg.Message)) { - return - } + c.server.completeRoundTrip(msg.Stream, &msg, i18n.NewError(c.ctx, i18n.MsgWSErrorFromClient, msg.Message)) default: log.L(c.ctx).Errorf("Unexpected message type: %+v", msg) } } } - -func (c *webSocketConnection) dispatchAckOrError(t *webSocketStream, msg *WebSocketCommandMessage, err error) bool { - if err != nil { - log.L(c.ctx).Debugf("Received WebSocket error on stream '%s': %s", t.streamName, err) - } else { - log.L(c.ctx).Debugf("Received WebSocket ack for batch %d on stream '%s'", msg.BatchNumber, t.streamName) - } - select { - case t.receiverChannel <- &WebSocketCommandMessageOrError{Msg: msg, Err: err}: - default: - log.L(c.ctx).Debugf("Received WebSocket ack for batch %d on stream '%s'. Too many spurious acks - closing connection", msg.BatchNumber, t.streamName) - // This shouldn't happen, as the channel has a buffer. So this means the client has sent us a number of - // acks that are not on the right stream (so no event stream is attached). - // We cannot discard this ack and continue, but we cannot afford to block here either, so we close the websocket - return false - } - return true -} diff --git a/pkg/wsserver/wsserver.go b/pkg/wsserver/wsserver.go index d8bb009..ff8b8c2 100644 --- a/pkg/wsserver/wsserver.go +++ b/pkg/wsserver/wsserver.go @@ -1,4 +1,4 @@ -// Copyright © 2023 Kaleido, Inc. +// Copyright © 2024 Kaleido, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -19,63 +19,80 @@ package wsserver import ( "context" "net/http" - "reflect" "sync" - "time" "github.com/gorilla/websocket" "github.com/hyperledger/firefly-common/pkg/i18n" "github.com/hyperledger/firefly-common/pkg/log" ) -// WebSocketChannels is provided to allow us to do a blocking send to a namespace that will complete once a client connects on it -// We also provide a channel to listen on for closing of the connection, to allow a select to wake on a blocking send -type WebSocketChannels interface { - GetChannels(streamName string) (senderChannel chan<- interface{}, broadcastChannel chan<- interface{}, receiverChannel <-chan *WebSocketCommandMessageOrError) +// The Protocol interface layers a protocol on top of raw websockets, that allows the server side to: +// - Model the concept of multiple "streams" on a single WebSocket +// - Block until 1 or more connections are available that have "started" a particular stream +// - Send a broadcast to all connections on a stream +// - Send a single payload to a single selected connection on a stream, and wait for an "ack" back +// from that specific websocket connection (or that websocket connection to disconnect) +// +// NOTE: This replaces a previous WebSocketChannels interface, which started its life in 2018 +// +// and attempted to solve the above problem set in a different way that had a challenging timing issue. +type Protocol interface { + // Broadcast performs best-effort delivery to all connections currently active on the specified stream + Broadcast(ctx context.Context, stream string, payload interface{}) + + // NextRoundTrip blocks until at least one connection is started on the stream, and then + // returns an interface that can be used to send a payload to exactly one of the attached + // connections, and receive an ack/error from just the one connection that was picked. + // - Returns an error if the context is closed. + RoundTrip(ctx context.Context, stream string, payload WSBatch) (*WebSocketCommandMessage, error) +} + +// WSBatch is any serializable structure that contains the batch header +type WSBatch interface { + GetBatchHeader() *BatchHeader +} + +type BatchHeader struct { + BatchNumber int64 `json:"batchNumber"` + Stream string `json:"stream"` } // WebSocketServer is the full server interface with the init call type WebSocketServer interface { - WebSocketChannels + Protocol Handler(w http.ResponseWriter, r *http.Request) Close() } -type webSocketServer struct { - ctx context.Context - processingTimeout time.Duration - mux sync.Mutex - streams map[string]*webSocketStream - streamMap map[string]map[string]*webSocketConnection - newStream chan bool - replyChannel chan interface{} - upgrader *websocket.Upgrader - connections map[string]*webSocketConnection +type streamState struct { + wlmCounter int64 + inflight *roundTrip + conns []*webSocketConnection } -type webSocketStream struct { - streamName string - senderChannel chan interface{} - broadcastChannel chan interface{} - receiverChannel chan *WebSocketCommandMessageOrError +type webSocketServer struct { + ctx context.Context + conf WebSocketServerConfig + streamMap map[string]*streamState + streamMapChange chan struct{} + mux sync.Mutex + upgrader *websocket.Upgrader + connections map[string]*webSocketConnection } // NewWebSocketServer create a new server with a simplified interface -func NewWebSocketServer(bgCtx context.Context) WebSocketServer { +func NewWebSocketServer(bgCtx context.Context, config *WebSocketServerConfig) WebSocketServer { s := &webSocketServer{ - ctx: bgCtx, - connections: make(map[string]*webSocketConnection), - streams: make(map[string]*webSocketStream), - streamMap: make(map[string]map[string]*webSocketConnection), - newStream: make(chan bool), - replyChannel: make(chan interface{}), - processingTimeout: 30 * time.Second, + ctx: bgCtx, + connections: make(map[string]*webSocketConnection), + streamMap: make(map[string]*streamState), + streamMapChange: make(chan struct{}), + conf: *config, upgrader: &websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, + ReadBufferSize: int(config.ReadBufferSize), + WriteBufferSize: int(config.WriteBufferSize), }, } - go s.processBroadcasts() return s } @@ -91,25 +108,22 @@ func (s *webSocketServer) Handler(w http.ResponseWriter, r *http.Request) { s.connections[c.id] = c } -func (s *webSocketServer) cycleStream(connInfo string, t *webSocketStream) { - s.mux.Lock() - defer s.mux.Unlock() - - // When a connection that was listening on a stream closes, we need to wake anyone - // that was listening for a response - select { - case t.receiverChannel <- &WebSocketCommandMessageOrError{Err: i18n.NewError(s.ctx, i18n.MsgWebSocketClosed, connInfo)}: - default: - } -} - func (s *webSocketServer) connectionClosed(c *webSocketConnection) { s.mux.Lock() defer s.mux.Unlock() delete(s.connections, c.id) - for _, stream := range c.streams { - delete(s.streamMap[stream.streamName], c.id) + for _, ss := range s.streamMap { + // Remove the connection + newSSConns := make([]*webSocketConnection, 0, len(ss.conns)) + for _, ssConn := range ss.conns { + if ssConn.id != c.id { + newSSConns = append(newSSConns, ssConn) + } + } + ss.conns = newSSConns } + close(s.streamMapChange) + s.streamMapChange = make(chan struct{}) } func (s *webSocketServer) Close() { @@ -118,92 +132,159 @@ func (s *webSocketServer) Close() { } } -func (s *webSocketServer) getStream(stream string) *webSocketStream { +func (s *webSocketServer) Broadcast(ctx context.Context, stream string, payload interface{}) { s.mux.Lock() - t, exists := s.streams[stream] - if !exists { - t = &webSocketStream{ - streamName: stream, - senderChannel: make(chan interface{}), - broadcastChannel: make(chan interface{}), - receiverChannel: make(chan *WebSocketCommandMessageOrError, 10), - } - s.streams[stream] = t - s.streamMap[stream] = make(map[string]*webSocketConnection) + ss := s.streamMap[stream] + if ss == nil { + ss = &streamState{} + s.streamMap[stream] = ss } + conns := make([]*webSocketConnection, len(ss.conns)) + copy(conns, ss.conns) s.mux.Unlock() - if !exists { - // Signal to the broadcaster that a new stream has been added - s.newStream <- true + for _, c := range conns { + select { + case c.send <- payload: + case <-c.closing: + // This isn't an error, just move on + log.L(ctx).Warnf("broadcast failed to closing connection '%s'", c.id) + } } - return t } -func (s *webSocketServer) GetChannels(stream string) (chan<- interface{}, chan<- interface{}, <-chan *WebSocketCommandMessageOrError) { - t := s.getStream(stream) - return t.senderChannel, t.broadcastChannel, t.receiverChannel +type roundTrip struct { + ss *streamState + conn *webSocketConnection + batchNumber int64 + done chan struct{} + err error + response *WebSocketCommandMessage } -func (s *webSocketServer) StreamStarted(c *webSocketConnection, stream string) { - // Track that this connection is interested in this stream - s.streamMap[stream][c.id] = c -} +func (s *webSocketServer) RoundTrip(ctx context.Context, stream string, payload WSBatch) (*WebSocketCommandMessage, error) { + var rt *roundTrip + err := s.waitStreamConnections(ctx, stream, func(ss *streamState) error { + // If there's an inflight already, that's an error - the caller is required to call NextRoundTrip sequentially, + // and always handle the cleanup of the RoundTripper + if ss.inflight != nil { + return i18n.NewError(s.ctx, i18n.MsgWebSocketBatchInflight, stream, ss.inflight.batchNumber, ss.inflight.conn.id) + } + // Do a round-robbin pick of one of the connections + conn := ss.conns[ss.wlmCounter%int64(len(ss.conns))] + ss.wlmCounter++ + // The wlmCounter is used as the batch number in the payload + payload.GetBatchHeader().BatchNumber = ss.wlmCounter + payload.GetBatchHeader().Stream = stream + rt = &roundTrip{ + ss: ss, + conn: conn, + batchNumber: ss.wlmCounter, // batch number increments in this library + done: make(chan struct{}), + } + ss.inflight = rt + return nil + }) + if err != nil { + return nil, err + } -func (s *webSocketServer) processBroadcasts() { - var streams []string - buildCases := func() []reflect.SelectCase { - // only hold the lock while we're building the list of cases (not while doing the select) + // Ensure we clean up before returning, including in error cases + defer func() { s.mux.Lock() - defer s.mux.Unlock() - streams = make([]string, len(s.streams)) - cases := make([]reflect.SelectCase, len(s.streams)+1) - i := 0 - for _, t := range s.streams { - cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(t.broadcastChannel)} - streams[i] = t.streamName - i++ - } - cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(s.newStream)} - return cases + rt.ss.inflight = nil + s.mux.Unlock() + }() + + // Send the payload to initiate the exchange + select { + case rt.conn.send <- payload: + case <-rt.conn.closing: + // handle connection closure while waiting to send + return nil, i18n.NewError(s.ctx, i18n.MsgWebSocketClosed, rt.conn.id) + case <-ctx.Done(): + // ... or stop/reset of stream, or server shutdown + return nil, i18n.NewError(s.ctx, i18n.MsgContextCanceled) } - cases := buildCases() - for { - chosen, value, ok := reflect.Select(cases) - if !ok { - log.L(s.ctx).Warn("An error occurred broadcasting the message") - return - } - if chosen == len(cases)-1 { - // Addition of a new stream - cases = buildCases() - } else { - // Message on one of the existing streams - // Gather all connections interested in this stream and send to them - s.mux.Lock() - stream := streams[chosen] - wsconns := getConnListFromMap(s.streamMap[stream]) - s.mux.Unlock() - s.broadcastToConnections(wsconns, value.Interface()) - } + // Set the processing timeout to wait for batch acknowledgement + ctxWithTimeout, cancelTimeout := context.WithTimeout(ctx, s.conf.AckTimeout) + defer cancelTimeout() + + // Wait for the response + select { + case <-rt.done: + case <-rt.conn.closing: + // handle connection closure while waiting for ack + return nil, i18n.NewError(s.ctx, i18n.MsgWebSocketClosed, rt.conn.id) + case <-ctxWithTimeout.Done(): + // ... or time out, stop/reset of stream, or server shutdown + return nil, i18n.NewError(s.ctx, i18n.MsgWebSocketRoundTripTimeout) } + return rt.response, rt.err } -// getConnListFromMap is a simple helper to snapshot a map into a list, which can be called with a short-lived lock -func getConnListFromMap(tm map[string]*webSocketConnection) []*webSocketConnection { - wsconns := make([]*webSocketConnection, 0, len(tm)) - for _, c := range tm { - wsconns = append(wsconns, c) +func (s *webSocketServer) completeRoundTrip(stream string, msg *WebSocketCommandMessage, err error) { + s.mux.Lock() + defer s.mux.Unlock() + ss := s.streamMap[stream] + if ss == nil || ss.inflight == nil { + log.L(s.ctx).Warnf("Received spurious ack for batchNumber=%d while no batch in-flight for stream '%s'", msg.BatchNumber, stream) + return } - return wsconns + rt := ss.inflight + // We accept batchNumber: 0 (omitted) as an nack/ack for the last thing sent + if msg.BatchNumber > 0 && rt.batchNumber != msg.BatchNumber { + log.L(s.ctx).Warnf("Received spurious ack for batchNumber=%d while batchNumber=%d in-flight for stream '%s'", msg.BatchNumber, rt.batchNumber, stream) + return + } + rt.response = msg + rt.err = err + close(rt.done) + // We are NOT responsible for clearing ss.inflight - that is the RoundTrip() function's job } -func (s *webSocketServer) broadcastToConnections(connections []*webSocketConnection, message interface{}) { - for _, c := range connections { +// waits until at least one connection is started on the requested stream, and returns an +// snapshot list of all connections on that stream. +func (s *webSocketServer) waitStreamConnections(ctx context.Context, stream string, lockedCallback func(ss *streamState) error) error { + for { + // check if there are connections + s.mux.Lock() + streamMapChange := s.streamMapChange + ss := s.streamMap[stream] + if ss != nil && len(ss.conns) > 0 { + err := lockedCallback(ss) + s.mux.Unlock() + return err + } + s.mux.Unlock() select { - case c.broadcast <- message: - case <-c.closing: - log.L(s.ctx).Warnf("Connection %s closed while attempting to deliver reply", c.id) + case <-streamMapChange: + case <-ctx.Done(): + return i18n.NewError(ctx, i18n.MsgContextCanceled) } } } + +func (s *webSocketServer) streamStarted(c *webSocketConnection, stream string) { + // Track that this connection is interested in this stream + s.mux.Lock() + defer s.mux.Unlock() + ss := s.streamMap[stream] + if ss == nil { + ss = &streamState{} + s.streamMap[stream] = ss + } + // Ignore duplicate starts on the same connection + found := false + for _, existing := range ss.conns { + if existing.id == c.id { + found = true + } + } + if !found { + ss.conns = append(ss.conns, c) + } + // Notify anyone waiting for a connection, and setup the next waiter + close(s.streamMapChange) + s.streamMapChange = make(chan struct{}) +} diff --git a/pkg/wsserver/wsserver_test.go b/pkg/wsserver/wsserver_test.go index 9c3a864..14452fe 100644 --- a/pkg/wsserver/wsserver_test.go +++ b/pkg/wsserver/wsserver_test.go @@ -21,26 +21,42 @@ import ( "net/http" "net/http/httptest" "net/url" - "sync" "testing" "time" ws "github.com/gorilla/websocket" + "github.com/hyperledger/firefly-common/pkg/config" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) +type testPayload struct { + BatchHeader + Message string +} + +func (tp *testPayload) GetBatchHeader() *BatchHeader { + return &tp.BatchHeader +} + func newTestWebSocketServer() (*webSocketServer, *httptest.Server) { - s := NewWebSocketServer(context.Background()).(*webSocketServer) + config.RootConfigReset() + utConf := config.RootSection("ut") + InitConfig(utConf) + s := NewWebSocketServer(context.Background(), GenerateConfig(utConf)).(*webSocketServer) ts := httptest.NewServer(http.HandlerFunc(s.Handler)) return s, ts } -func TestConnectSendReceiveCycle(t *testing.T) { +func TestRoundTripStartFirstGood(t *testing.T) { + logrus.SetLevel(logrus.TraceLevel) + assert := assert.New(t) w, ts := newTestWebSocketServer() defer ts.Close() + defer w.Close() u, err := url.Parse(ts.URL) u.Scheme = "ws" @@ -48,43 +64,180 @@ func TestConnectSendReceiveCycle(t *testing.T) { c, _, err := ws.DefaultDialer.Dial(u.String(), nil) assert.NoError(err) + c.WriteJSON(&WebSocketCommandMessage{ + Type: "start", + // note this test runs on the "" (empty string) stream + }) + + // Double start will be ignored c.WriteJSON(&WebSocketCommandMessage{ Type: "start", }) - s, _, r := w.GetChannels("") + // Wait in this test until we've locked in the start + err = w.waitStreamConnections(w.ctx, "", func(ss *streamState) error { return nil }) + assert.NoError(err) - s <- "Hello World" + for i := 1; i <= 10; i++ { + + roundTripComplete := make(chan struct{}) + go func() { + cm, err := w.RoundTrip(w.ctx, "", &testPayload{Message: "Hello World"}) + assert.NoError(err) + assert.Equal("good ack", cm.Message) + defer close(roundTripComplete) + }() + + var received testPayload + err = c.ReadJSON(&received) + assert.NoError(err) + assert.Equal(int64(i), received.BatchNumber) + assert.Equal("Hello World", received.Message) + + c.WriteJSON(&WebSocketCommandMessage{ + Type: "ignoreme", + }) + + c.WriteJSON(&WebSocketCommandMessage{ + Type: "ack", + BatchNumber: 12345, + Message: "bad ack", + }) + + c.WriteJSON(&WebSocketCommandMessage{ + Type: "ack", + BatchNumber: received.BatchNumber, + Message: "good ack", + }) + <-roundTripComplete - var val string - c.ReadJSON(&val) - assert.Equal("Hello World", val) + } + +} + +func TestRoundTripStartWhileRoundTripWaitingGood(t *testing.T) { + logrus.SetLevel(logrus.TraceLevel) + + assert := assert.New(t) + + w, ts := newTestWebSocketServer() + defer ts.Close() + defer w.Close() + + u, err := url.Parse(ts.URL) + u.Scheme = "ws" + u.Path = "/ws" + c, _, err := ws.DefaultDialer.Dial(u.String(), nil) + assert.NoError(err) + + roundTripComplete := make(chan struct{}) + go func() { + cm, err := w.RoundTrip(w.ctx, "stream1", &testPayload{Message: "Hello World"}) + assert.NoError(err) + assert.Equal("good ack", cm.Message) + defer close(roundTripComplete) + }() + + time.Sleep(50 * time.Millisecond) c.WriteJSON(&WebSocketCommandMessage{ - Type: "ignoreme", + Type: "start", + Stream: "stream1", }) + var received testPayload + err = c.ReadJSON(&received) + assert.NoError(err) + assert.Equal(int64(1), received.BatchNumber) + assert.Equal("Hello World", received.Message) + c.WriteJSON(&WebSocketCommandMessage{ - Type: "ack", + Type: "ack", + BatchNumber: received.BatchNumber, + Message: "good ack", + Stream: "stream1", }) - msgOrErr := <-r - assert.NoError(msgOrErr.Err) + <-roundTripComplete +} - s <- "Don't Panic!" +func TestRoundTripInFlight(t *testing.T) { + logrus.SetLevel(logrus.TraceLevel) - c.ReadJSON(&val) - assert.Equal("Don't Panic!", val) + assert := assert.New(t) + + w, ts := newTestWebSocketServer() + defer ts.Close() + + u, err := url.Parse(ts.URL) + u.Scheme = "ws" + u.Path = "/ws" + c, _, err := ws.DefaultDialer.Dial(u.String(), nil) + assert.NoError(err) c.WriteJSON(&WebSocketCommandMessage{ - Type: "error", - Message: "Panic!", + Type: "start", + Stream: "stream1", }) - msgOrErr = <-r - assert.Regexp("Error received from WebSocket client: Panic!", msgOrErr.Err) + roundTripComplete := make(chan struct{}) + go func() { + defer close(roundTripComplete) + _, err := w.RoundTrip(w.ctx, "stream1", &testPayload{Message: "Hello World"}) + assert.Regexp("FF00228", err) + }() + + var received testPayload + err = c.ReadJSON(&received) + assert.NoError(err) + + _, err = w.RoundTrip(w.ctx, "stream1", &testPayload{Message: "Hello World"}) + assert.Regexp("FF00243", err) w.Close() + <-roundTripComplete +} + +func TestRoundTripClientNack(t *testing.T) { + logrus.SetLevel(logrus.TraceLevel) + + assert := assert.New(t) + + w, ts := newTestWebSocketServer() + defer ts.Close() + defer w.Close() + + u, err := url.Parse(ts.URL) + u.Scheme = "ws" + u.Path = "/ws" + c, _, err := ws.DefaultDialer.Dial(u.String(), nil) + assert.NoError(err) + + roundTripComplete := make(chan struct{}) + go func() { + cm, err := w.RoundTrip(w.ctx, "stream1", &testPayload{Message: "Don't Panic!"}) + assert.Regexp("Error received from WebSocket client: Panic!", err) + assert.Equal("Panic!", cm.Message) + defer close(roundTripComplete) + }() + + c.WriteJSON(&WebSocketCommandMessage{ + Type: "start", + Stream: "stream1", + }) + var received testPayload + err = c.ReadJSON(&received) + assert.NoError(err) + assert.Equal(int64(1), received.BatchNumber) + assert.Equal("Don't Panic!", received.Message) + + c.WriteJSON(&WebSocketCommandMessage{ + Type: "error", + Message: "Panic!", + Stream: "stream1", + // Note we drive fallback "last batch" processing here, as we omit a batch number + }) + <-roundTripComplete } func TestConnectStreamIsolation(t *testing.T) { @@ -92,6 +245,7 @@ func TestConnectStreamIsolation(t *testing.T) { w, ts := newTestWebSocketServer() defer ts.Close() + defer w.Close() u, err := url.Parse(ts.URL) u.Scheme = "ws" @@ -111,32 +265,44 @@ func TestConnectStreamIsolation(t *testing.T) { Stream: "stream2", }) - s1, _, r1 := w.GetChannels("stream1") - s2, _, r2 := w.GetChannels("stream2") - - s1 <- "Hello Number 1" - s2 <- "Hello Number 2" + roundTripComplete1 := make(chan struct{}) + go func() { + cm, err := w.RoundTrip(w.ctx, "stream1", &testPayload{Message: "Hello Number 1"}) + assert.NoError(err) + assert.Equal("stream1 message", cm.Message) + defer close(roundTripComplete1) + }() + roundTripComplete2 := make(chan struct{}) + go func() { + cm, err := w.RoundTrip(w.ctx, "stream2", &testPayload{Message: "Hello Number 2"}) + assert.NoError(err) + assert.Equal("stream2 message", cm.Message) + defer close(roundTripComplete2) + }() - var val string - c1.ReadJSON(&val) - assert.Equal("Hello Number 1", val) + var received testPayload + err = c1.ReadJSON(&received) + assert.NoError(err) + assert.Equal(int64(1), received.BatchNumber) + assert.Equal("Hello Number 1", received.Message) c1.WriteJSON(&WebSocketCommandMessage{ - Type: "ack", - Stream: "stream1", + Type: "ack", + Stream: "stream1", + Message: "stream1 message", }) - msgOrErr := <-r1 - assert.NoError(msgOrErr.Err) - c2.ReadJSON(&val) - assert.Equal("Hello Number 2", val) + err = c2.ReadJSON(&received) + assert.NoError(err) + assert.Equal(int64(1), received.BatchNumber) + assert.Equal("Hello Number 2", received.Message) c2.WriteJSON(&WebSocketCommandMessage{ - Type: "ack", - Stream: "stream2", + Type: "ack", + Stream: "stream2", + Message: "stream2 message", }) - msgOrErr = <-r2 - assert.NoError(msgOrErr.Err) - w.Close() + <-roundTripComplete1 + <-roundTripComplete2 } @@ -145,6 +311,7 @@ func TestConnectAbandonRequest(t *testing.T) { w, ts := newTestWebSocketServer() defer ts.Close() + defer w.Close() u, err := url.Parse(ts.URL) u.Scheme = "ws" @@ -153,108 +320,72 @@ func TestConnectAbandonRequest(t *testing.T) { assert.NoError(err) c.WriteJSON(&WebSocketCommandMessage{ - Type: "start", + Type: "start", + Stream: "stream1", }) - _, _, r := w.GetChannels("") - wg := sync.WaitGroup{} - wg.Add(1) + // Wait in this test until we've locked in the start + err = w.waitStreamConnections(w.ctx, "stream1", func(ss *streamState) error { return nil }) + assert.NoError(err) + + roundTripComplete := make(chan struct{}) go func() { - select { - case <-r: - break - } - wg.Done() + _, err := w.RoundTrip(w.ctx, "stream1", &testPayload{Message: "Hello World"}) + assert.Regexp("FF00228", err) + defer close(roundTripComplete) }() // Close the client while we've got an active read stream c.Close() - // We whould find the read stream closes out - wg.Wait() - w.Close() - + // We should find the read stream closes out + <-roundTripComplete } -func TestSpuriousAckProcessing(t *testing.T) { +func TestConnectBadWebsocketHandshake(t *testing.T) { assert := assert.New(t) w, ts := newTestWebSocketServer() defer ts.Close() - w.processingTimeout = 1 * time.Millisecond - u, err := url.Parse(ts.URL) - u.Scheme = "ws" + u, _ := url.Parse(ts.URL) u.Path = "/ws" - c, _, err := ws.DefaultDialer.Dial(u.String(), nil) - assert.NoError(err) - - // Drop depth to 1 for spurious ack processing - stream := w.getStream("mystream") - stream.receiverChannel = make(chan *WebSocketCommandMessageOrError, 1) - c.WriteJSON(&WebSocketCommandMessage{ - Type: "ack", - Stream: "mystream", - }) - c.WriteJSON(&WebSocketCommandMessage{ - Type: "ack", - Stream: "mystream", - }) + res, err := http.Get(u.String()) + assert.NoError(err) + assert.Equal(400, res.StatusCode) - for len(w.connections) > 0 { - time.Sleep(1 * time.Millisecond) - } w.Close() + } -func TestSpuriousNackProcessing(t *testing.T) { +func TestBroadcastStartWithoutConnections(t *testing.T) { assert := assert.New(t) w, ts := newTestWebSocketServer() defer ts.Close() - w.processingTimeout = 1 * time.Millisecond + defer w.Close() - u, err := url.Parse(ts.URL) + u, _ := url.Parse(ts.URL) u.Scheme = "ws" u.Path = "/ws" - c, _, err := ws.DefaultDialer.Dial(u.String(), nil) - assert.NoError(err) + stream := "banana" - // Drop depth to 1 for spurious ack processing - stream := w.getStream("mystream") - stream.receiverChannel = make(chan *WebSocketCommandMessageOrError, 1) + go w.Broadcast(w.ctx, stream, "Hello World") - c.WriteJSON(&WebSocketCommandMessage{ - Type: "ack", - Stream: "mystream", - }) - c.WriteJSON(&WebSocketCommandMessage{ - Type: "error", - Stream: "mystream", + c1, _, err := ws.DefaultDialer.Dial(u.String(), nil) + assert.NoError(err) + c1.WriteJSON(&WebSocketCommandMessage{ + Type: "start", + Stream: stream, }) - for len(w.connections) > 0 { - time.Sleep(1 * time.Millisecond) - } - w.Close() -} - -func TestConnectBadWebsocketHandshake(t *testing.T) { - assert := assert.New(t) - - w, ts := newTestWebSocketServer() - defer ts.Close() - - u, _ := url.Parse(ts.URL) - u.Path = "/ws" - - res, err := http.Get(u.String()) + c2, _, err := ws.DefaultDialer.Dial(u.String(), nil) assert.NoError(err) - assert.Equal(400, res.StatusCode) - - w.Close() - + c2.WriteJSON(&WebSocketCommandMessage{ + Type: "start", + Stream: stream, + }) } func TestBroadcast(t *testing.T) { @@ -262,134 +393,181 @@ func TestBroadcast(t *testing.T) { w, ts := newTestWebSocketServer() defer ts.Close() + defer w.Close() u, _ := url.Parse(ts.URL) u.Scheme = "ws" u.Path = "/ws" stream := "banana" - c, _, err := ws.DefaultDialer.Dial(u.String(), nil) + + c1, _, err := ws.DefaultDialer.Dial(u.String(), nil) assert.NoError(err) + c1.WriteJSON(&WebSocketCommandMessage{ + Type: "start", + Stream: stream, + }) - c.WriteJSON(&WebSocketCommandMessage{ + c2, _, err := ws.DefaultDialer.Dial(u.String(), nil) + assert.NoError(err) + c2.WriteJSON(&WebSocketCommandMessage{ Type: "start", Stream: stream, }) - // Wait until the client has subscribed to the stream before proceeding - for len(w.streamMap[stream]) == 0 { + // Wait until the clients have subscribed to the stream before proceeding + count := 0 + for count < 2 { time.Sleep(10 * time.Millisecond) + err = w.waitStreamConnections(w.ctx, stream, func(ss *streamState) error { + count = len(ss.conns) + return nil + }) + assert.NoError(err) } - _, b, _ := w.GetChannels(stream) - b <- "Hello World" - var val string - c.ReadJSON(&val) - assert.Equal("Hello World", val) - b <- "Hello World Again" + w.Broadcast(w.ctx, stream, "Hello World") + c1.ReadJSON(&val) + assert.Equal("Hello World", val) + c2.ReadJSON(&val) + assert.Equal("Hello World", val) - c.ReadJSON(&val) + w.Broadcast(w.ctx, stream, "Hello World Again") + c1.ReadJSON(&val) + assert.Equal("Hello World Again", val) + c2.ReadJSON(&val) assert.Equal("Hello World Again", val) - w.Close() } -func TestBroadcastDefaultStream(t *testing.T) { +func TestActionsAfterClose(t *testing.T) { assert := assert.New(t) w, ts := newTestWebSocketServer() defer ts.Close() + defer w.Close() u, _ := url.Parse(ts.URL) u.Scheme = "ws" u.Path = "/ws" - stream := "" c, _, err := ws.DefaultDialer.Dial(u.String(), nil) assert.NoError(err) c.WriteJSON(&WebSocketCommandMessage{ - Type: "start", + Type: "start", + Stream: "stream1", }) - // Wait until the client has subscribed to the stream before proceeding - for len(w.streamMap[stream]) == 0 { - time.Sleep(10 * time.Millisecond) - } - - _, b, _ := w.GetChannels(stream) - b <- "Hello World" + closedCtx, closeCtx := context.WithCancel(context.Background()) + closeCtx() + err = w.waitStreamConnections(closedCtx, "stream1", func(ss *streamState) error { return nil }) + assert.Regexp("FF00154", err) - var val string - c.ReadJSON(&val) - assert.Equal("Hello World", val) + var conn *webSocketConnection + err = w.waitStreamConnections(w.ctx, "stream1", func(ss *streamState) error { + conn = ss.conns[0] + return nil + }) + assert.NoError(err) - b <- "Hello World Again" + c.Close() + <-conn.closing + + // Send after close + conn.send = make(chan interface{}) + conn.closing = make(chan struct{}) + go func() { conn.send <- "anything" }() + conn.sender() + + // Broadcast after close + close(conn.closing) + w.streamMap = map[string]*streamState{ + "stream1": { + conns: []*webSocketConnection{ + conn, + }, + }, + } + w.Broadcast(w.ctx, "stream1", "test1") - c.ReadJSON(&val) - assert.Equal("Hello World Again", val) + // RoundTrip after close + _, err = w.RoundTrip(w.ctx, "stream1", &testPayload{}) + assert.Regexp("FF00228", err) - w.Close() + // round trip not in fight + w.completeRoundTrip("stream1", &WebSocketCommandMessage{}, nil) } -func TestRecvNotOk(t *testing.T) { +func TestRoundTripClientAckTimeout(t *testing.T) { + logrus.SetLevel(logrus.TraceLevel) + assert := assert.New(t) w, ts := newTestWebSocketServer() defer ts.Close() + defer w.Close() + w.conf.AckTimeout = 1 * time.Microsecond - u, _ := url.Parse(ts.URL) + u, err := url.Parse(ts.URL) u.Scheme = "ws" u.Path = "/ws" - stream := "" c, _, err := ws.DefaultDialer.Dial(u.String(), nil) assert.NoError(err) + roundTripComplete := make(chan struct{}) + go func() { + _, err := w.RoundTrip(w.ctx, "stream1", &testPayload{Message: "anybody there"}) + assert.Regexp("FF00244", err) + defer close(roundTripComplete) + }() + c.WriteJSON(&WebSocketCommandMessage{ - Type: "start", + Type: "start", + Stream: "stream1", }) - // Wait until the client has subscribed to the stream before proceeding - for len(w.streamMap[stream]) == 0 { - time.Sleep(10 * time.Millisecond) - } + var received testPayload + err = c.ReadJSON(&received) + assert.NoError(err) + assert.Equal(int64(1), received.BatchNumber) + assert.Equal("anybody there", received.Message) - _, b, _ := w.GetChannels(stream) - close(b) - w.Close() + <-roundTripComplete } -func TestListenStreamClosing(t *testing.T) { +func TestRoundTripContextCancelled(t *testing.T) { + logrus.SetLevel(logrus.TraceLevel) + + assert := assert.New(t) w, ts := newTestWebSocketServer() defer ts.Close() - w.getStream("test") + defer w.Close() + w.conf.AckTimeout = 1 * time.Microsecond - c := &webSocketConnection{ - server: w, - streams: make(map[string]*webSocketStream), - closing: make(chan struct{}), - newStream: make(chan bool), - } - close(c.closing) - c.startStream(&webSocketStream{ - streamName: "test", + u, err := url.Parse(ts.URL) + u.Scheme = "ws" + u.Path = "/ws" + c, _, err := ws.DefaultDialer.Dial(u.String(), nil) + assert.NoError(err) + defer c.Close() + + c.WriteJSON(&WebSocketCommandMessage{ + Type: "start", + Stream: "stream1", }) -} -func TestBroadcastClosing(t *testing.T) { + err = w.waitStreamConnections(w.ctx, "stream1", func(ss *streamState) error { return nil }) + assert.NoError(err) - w, ts := newTestWebSocketServer() - defer ts.Close() - w.getStream("test") + // close the context before round trip + canceledCtx, cancelCtx := context.WithCancel(context.Background()) + cancelCtx() - c := &webSocketConnection{ - server: w, - streams: make(map[string]*webSocketStream), - closing: make(chan struct{}), - newStream: make(chan bool), - } - close(c.closing) - // Check this doesn't block - c.server.broadcastToConnections([]*webSocketConnection{c}, "anything") + // Block the send + w.streamMap["stream1"].conns[0].send = make(chan interface{}) + + _, err = w.RoundTrip(canceledCtx, "stream1", &testPayload{Message: "anybody there"}) + assert.Regexp("FF00154", err) }