From 4affa8a51140fbe3edbf1c0c67a6387339c152a0 Mon Sep 17 00:00:00 2001 From: Tudor Malene Date: Wed, 3 Apr 2024 12:58:02 +0100 Subject: [PATCH] New heads subscription (#1861) * new heads subscription * fix * add scaffolding for cache invalidation * fix * fix --- go/common/host/host.go | 12 +- go/common/host/services.go | 1 + go/common/subscription/new_heads_manager.go | 116 ++++++++++++++++++ go/common/subscription/utils.go | 69 +++++++++++ go/host/container/host_container.go | 5 +- go/host/host.go | 19 ++- go/host/rpc/clientapi/client_api_filter.go | 74 +++++------ go/rpc/client.go | 19 +-- integration/obscurogateway/tengateway_test.go | 32 +++++ tools/walletextension/rpcapi/filter_api.go | 90 ++++---------- tools/walletextension/rpcapi/utils.go | 8 +- .../rpcapi/wallet_extension.go | 45 ++++++- .../walletextension_container.go | 39 +++--- 13 files changed, 374 insertions(+), 155 deletions(-) create mode 100644 go/common/subscription/new_heads_manager.go create mode 100644 go/common/subscription/utils.go diff --git a/go/common/host/host.go b/go/common/host/host.go index 154dc4f127..08799ab104 100644 --- a/go/common/host/host.go +++ b/go/common/host/host.go @@ -19,10 +19,10 @@ type Host interface { Start() error // SubmitAndBroadcastTx submits an encrypted transaction to the enclave, and broadcasts it to the other hosts on the network. SubmitAndBroadcastTx(encryptedParams common.EncryptedParamsSendRawTx) (*responses.RawTx, error) - // Subscribe feeds logs matching the encrypted log subscription to the matchedLogs channel. - Subscribe(id rpc.ID, encryptedLogSubscription common.EncryptedParamsLogSubscription, matchedLogs chan []byte) error - // Unsubscribe terminates a log subscription between the host and the enclave. - Unsubscribe(id rpc.ID) + // SubscribeLogs feeds logs matching the encrypted log subscription to the matchedLogs channel. + SubscribeLogs(id rpc.ID, encryptedLogSubscription common.EncryptedParamsLogSubscription, matchedLogs chan []byte) error + // UnsubscribeLogs terminates a log subscription between the host and the enclave. + UnsubscribeLogs(id rpc.ID) // Stop gracefully stops the host execution. Stop() error @@ -31,6 +31,10 @@ type Host interface { // ObscuroConfig returns the info of the Obscuro network ObscuroConfig() (*common.ObscuroNetworkInfo, error) + + // NewHeadsChan returns live batch headers + // Note - do not use directly. This is meant only for the NewHeadsManager, which multiplexes the headers + NewHeadsChan() chan *common.BatchHeader } type BlockStream struct { diff --git a/go/common/host/services.go b/go/common/host/services.go index d40b236025..3c1d532704 100644 --- a/go/common/host/services.go +++ b/go/common/host/services.go @@ -20,6 +20,7 @@ const ( L2BatchRepositoryName = "l2-batch-repo" EnclaveServiceName = "enclaves" LogSubscriptionServiceName = "log-subs" + FilterAPIServiceName = "filter-api" ) // The host has a number of services that encapsulate the various responsibilities of the host. diff --git a/go/common/subscription/new_heads_manager.go b/go/common/subscription/new_heads_manager.go new file mode 100644 index 0000000000..7c82708147 --- /dev/null +++ b/go/common/subscription/new_heads_manager.go @@ -0,0 +1,116 @@ +package subscription + +import ( + "math/big" + "sync" + "sync/atomic" + + gethcommon "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + + gethlog "github.com/ethereum/go-ethereum/log" + "github.com/ten-protocol/go-ten/go/common" + "github.com/ten-protocol/go-ten/go/common/host" + "github.com/ten-protocol/go-ten/go/common/log" + "github.com/ten-protocol/go-ten/lib/gethfork/rpc" +) + +// NewHeadsService multiplexes new batch header messages from an input channel into multiple subscribers +// also handles unsubscribe +// Note: this is a service which must be Started and Stopped +type NewHeadsService struct { + inputCh chan *common.BatchHeader + convertToEthHeader bool + notifiersMutex *sync.RWMutex + newHeadNotifiers map[rpc.ID]*rpc.Notifier + onMessage func(*common.BatchHeader) error + stopped *atomic.Bool + logger gethlog.Logger +} + +func NewNewHeadsService(inputCh chan *common.BatchHeader, convertToEthHeader bool, logger gethlog.Logger, onMessage func(*common.BatchHeader) error) *NewHeadsService { + return &NewHeadsService{ + inputCh: inputCh, + convertToEthHeader: convertToEthHeader, + onMessage: onMessage, + logger: logger, + stopped: &atomic.Bool{}, + newHeadNotifiers: make(map[rpc.ID]*rpc.Notifier), + notifiersMutex: &sync.RWMutex{}, + } +} + +func (nhs *NewHeadsService) Start() error { + go ForwardFromChannels([]chan *common.BatchHeader{nhs.inputCh}, nhs.stopped, func(head *common.BatchHeader) error { + nhs.notifiersMutex.RLock() + defer nhs.notifiersMutex.RUnlock() + + if nhs.onMessage != nil { + err := nhs.onMessage(head) + if err != nil { + nhs.logger.Info("failed invoking onMessage callback.", log.ErrKey, err) + } + } + + var msg any = head + if nhs.convertToEthHeader { + msg = convertBatchHeader(head) + } + + // for each new head, notify all registered subscriptions + for id, notifier := range nhs.newHeadNotifiers { + if nhs.stopped.Load() { + return nil + } + err := notifier.Notify(id, msg) + if err != nil { + // on error, remove the notification + nhs.logger.Info("failed to notify newHead subscription", log.ErrKey, err, log.SubIDKey, id) + nhs.notifiersMutex.Lock() + delete(nhs.newHeadNotifiers, id) + nhs.notifiersMutex.Unlock() + } + } + return nil + }) + return nil +} + +func (nhs *NewHeadsService) RegisterNotifier(notifier *rpc.Notifier, subscription *rpc.Subscription) { + nhs.notifiersMutex.Lock() + defer nhs.notifiersMutex.Unlock() + nhs.newHeadNotifiers[subscription.ID] = notifier + + go HandleUnsubscribe(subscription, nil, func() { + nhs.notifiersMutex.Lock() + defer nhs.notifiersMutex.Unlock() + delete(nhs.newHeadNotifiers, subscription.ID) + }) +} + +func (nhs *NewHeadsService) Stop() error { + nhs.stopped.Store(true) + return nil +} + +func (nhs *NewHeadsService) HealthStatus() host.HealthStatus { + return &host.BasicErrHealthStatus{} +} + +func convertBatchHeader(head *common.BatchHeader) *types.Header { + return &types.Header{ + ParentHash: head.ParentHash, + UncleHash: gethcommon.Hash{}, + Root: head.Root, + TxHash: head.TxHash, + ReceiptHash: head.ReceiptHash, + Bloom: types.Bloom{}, + Difficulty: big.NewInt(0), + Number: head.Number, + GasLimit: head.GasLimit, + GasUsed: head.GasUsed, + Time: head.Time, + Extra: make([]byte, 0), + BaseFee: head.BaseFee, + } +} diff --git a/go/common/subscription/utils.go b/go/common/subscription/utils.go new file mode 100644 index 0000000000..b67c352764 --- /dev/null +++ b/go/common/subscription/utils.go @@ -0,0 +1,69 @@ +package subscription + +import ( + "reflect" + "sync/atomic" + "time" + + "github.com/ten-protocol/go-ten/lib/gethfork/rpc" +) + +// ForwardFromChannels - reads messages from the input channels, and calls the `onMessage` callback. +// Exits when the unsubscribed flag is true. +// Must be called as a go routine! +func ForwardFromChannels[R any](inputChannels []chan R, unsubscribed *atomic.Bool, onMessage func(R) error) { + inputCases := make([]reflect.SelectCase, len(inputChannels)+1) + + // create a ticker to handle cleanup, check the "unsubscribed" flag and exit the goroutine + inputCases[0] = reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(time.NewTicker(2 * time.Second).C), + } + + // create a select "case" for each input channel + for i, ch := range inputChannels { + inputCases[i+1] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)} + } + + unclosedInputChannels := len(inputCases) + for unclosedInputChannels > 0 { + chosen, value, ok := reflect.Select(inputCases) + if !ok { + // The chosen channel has been closed, so zero out the channel to disable the case + inputCases[chosen].Chan = reflect.ValueOf(nil) + unclosedInputChannels-- + continue + } + + if unsubscribed.Load() { + return + } + + switch v := value.Interface().(type) { + case time.Time: + // exit the loop to avoid a goroutine leak + if unsubscribed.Load() { + return + } + case R: + err := onMessage(v) + if err != nil { + // todo - log + return + } + default: + // ignore unexpected element + continue + } + } +} + +// HandleUnsubscribe - when the client calls "unsubscribe" or the subscription times out, it calls `onSub` +// Must be called as a go routine! +func HandleUnsubscribe(connectionSub *rpc.Subscription, unsubscribed *atomic.Bool, onUnsub func()) { + <-connectionSub.Err() + if unsubscribed != nil { + unsubscribed.Store(true) + } + onUnsub() +} diff --git a/go/host/container/host_container.go b/go/host/container/host_container.go index c4b33081ba..2dd8367fe7 100644 --- a/go/host/container/host_container.go +++ b/go/host/container/host_container.go @@ -169,6 +169,7 @@ func NewHostContainer(cfg *config.HostConfig, services *host.ServicesRegistry, p } if cfg.HasClientRPCHTTP || cfg.HasClientRPCWebsockets { + filterAPI := clientapi.NewFilterAPI(h, logger) rpcServer.RegisterAPIs([]rpc.API{ { Namespace: APINamespaceObscuro, @@ -192,7 +193,7 @@ func NewHostContainer(cfg *config.HostConfig, services *host.ServicesRegistry, p }, { Namespace: APINamespaceEth, - Service: clientapi.NewFilterAPI(h, logger), + Service: filterAPI, }, { Namespace: APINamespaceScan, @@ -208,7 +209,7 @@ func NewHostContainer(cfg *config.HostConfig, services *host.ServicesRegistry, p }, }) } + services.RegisterService(hostcommon.FilterAPIServiceName, filterAPI.NewHeadsService) } - return hostContainer } diff --git a/go/host/host.go b/go/host/host.go index 6df417d983..4afbaaf0d7 100644 --- a/go/host/host.go +++ b/go/host/host.go @@ -47,6 +47,15 @@ type host struct { // l2MessageBusAddress is fetched from the enclave but cache it here because it never changes l2MessageBusAddress *gethcommon.Address + newHeads chan *common.BatchHeader +} + +type batchListener struct { + newHeads chan *common.BatchHeader +} + +func (bl batchListener) HandleBatch(batch *common.ExtBatch) { + bl.newHeads <- batch.Header } func NewHost(config *config.HostConfig, hostServices *ServicesRegistry, p2p hostcommon.P2PHostService, ethClient ethadapter.EthClient, l1Repo hostcommon.L1RepoService, enclaveClients []common.Enclave, ethWallet wallet.Wallet, mgmtContractLib mgmtcontractlib.MgmtContractLib, logger gethlog.Logger, regMetrics gethmetrics.Registry) hostcommon.Host { @@ -70,6 +79,7 @@ func NewHost(config *config.HostConfig, hostServices *ServicesRegistry, p2p host metricRegistry: regMetrics, stopControl: stopcontrol.New(), + newHeads: make(chan *common.BatchHeader), } enclGuardians := make([]*enclave.Guardian, 0, len(enclaveClients)) @@ -89,6 +99,7 @@ func NewHost(config *config.HostConfig, hostServices *ServicesRegistry, p2p host l2Repo := l2.NewBatchRepository(config, hostServices, database, logger) subsService := events.NewLogEventManager(hostServices, logger) + l2Repo.Subscribe(batchListener{newHeads: host.newHeads}) hostServices.RegisterService(hostcommon.P2PName, p2p) hostServices.RegisterService(hostcommon.L1BlockRepositoryName, l1Repo) maxWaitForL1Receipt := 6 * config.L1BlockTime // wait ~10 blocks to see if tx gets published before retrying @@ -158,14 +169,14 @@ func (h *host) SubmitAndBroadcastTx(encryptedParams common.EncryptedParamsSendRa return h.services.Enclaves().SubmitAndBroadcastTx(encryptedParams) } -func (h *host) Subscribe(id rpc.ID, encryptedLogSubscription common.EncryptedParamsLogSubscription, matchedLogsCh chan []byte) error { +func (h *host) SubscribeLogs(id rpc.ID, encryptedLogSubscription common.EncryptedParamsLogSubscription, matchedLogsCh chan []byte) error { if h.stopControl.IsStopping() { return responses.ToInternalError(fmt.Errorf("requested Subscribe with the host stopping")) } return h.services.LogSubs().Subscribe(id, encryptedLogSubscription, matchedLogsCh) } -func (h *host) Unsubscribe(id rpc.ID) { +func (h *host) UnsubscribeLogs(id rpc.ID) { if h.stopControl.IsStopping() { h.logger.Debug("requested Subscribe with the host stopping") } @@ -235,6 +246,10 @@ func (h *host) ObscuroConfig() (*common.ObscuroNetworkInfo, error) { }, nil } +func (h *host) NewHeadsChan() chan *common.BatchHeader { + return h.newHeads +} + // Checks the host config is valid. func (h *host) validateConfig() { if h.config.IsGenesis && h.config.NodeType != common.Sequencer { diff --git a/go/host/rpc/clientapi/client_api_filter.go b/go/host/rpc/clientapi/client_api_filter.go index 2c7ecfb626..9287c36a14 100644 --- a/go/host/rpc/clientapi/client_api_filter.go +++ b/go/host/rpc/clientapi/client_api_filter.go @@ -4,9 +4,9 @@ import ( "context" "fmt" "sync/atomic" - "time" "github.com/ten-protocol/go-ten/go/common/host" + subscriptioncommon "github.com/ten-protocol/go-ten/go/common/subscription" "github.com/ten-protocol/go-ten/go/responses" gethlog "github.com/ethereum/go-ethereum/log" @@ -19,18 +19,30 @@ import ( // FilterAPI exposes a subset of Geth's PublicFilterAPI operations. type FilterAPI struct { - host host.Host - logger gethlog.Logger + host host.Host + logger gethlog.Logger + NewHeadsService *subscriptioncommon.NewHeadsService } func NewFilterAPI(host host.Host, logger gethlog.Logger) *FilterAPI { return &FilterAPI{ - host: host, - logger: logger, + host: host, + logger: logger, + NewHeadsService: subscriptioncommon.NewNewHeadsService(host.NewHeadsChan(), false, logger, nil), } } -// Logs returns a log subscription. +func (api *FilterAPI) NewHeads(ctx context.Context) (*rpc.Subscription, error) { + notifier, supported := rpc.NotifierFromContext(ctx) + if !supported { + return nil, fmt.Errorf("creation of subscriptions is not supported") + } + subscription := notifier.CreateSubscription() + api.NewHeadsService.RegisterNotifier(notifier, subscription) + return subscription, nil +} + +// Logs exposes the "logs" rpc endpoint. func (api *FilterAPI) Logs(ctx context.Context, encryptedParams common.EncryptedParamsLogSubscription) (*rpc.Subscription, error) { notifier, supported := rpc.NotifierFromContext(ctx) if !supported { @@ -39,7 +51,7 @@ func (api *FilterAPI) Logs(ctx context.Context, encryptedParams common.Encrypted subscription := notifier.CreateSubscription() logsFromSubscription := make(chan []byte) - err := api.host.Subscribe(subscription.ID, encryptedParams, logsFromSubscription) + err := api.host.SubscribeLogs(subscription.ID, encryptedParams, logsFromSubscription) if err != nil { return nil, fmt.Errorf("could not subscribe for logs. Cause: %w", err) } @@ -51,49 +63,21 @@ func (api *FilterAPI) Logs(ctx context.Context, encryptedParams common.Encrypted SubID: subscription.ID, }) if err != nil { - api.host.Unsubscribe(subscription.ID) + api.host.UnsubscribeLogs(subscription.ID) return nil, fmt.Errorf("could not send subscription ID to client on subscription %s", subscription.ID) } var unsubscribed atomic.Bool - - go func() { - // to avoid unsubscribe deadlocks we have a 10 second delay between the unsubscribe command - // and the moment we stop listening for messages - for { - select { - case encryptedLog, ok := <-logsFromSubscription: - if !ok { - api.logger.Info("subscription channel closed", log.SubIDKey, subscription.ID) - return - } - if unsubscribed.Load() { - api.logger.Debug("subscription unsubscribed", log.SubIDKey, subscription.ID) - return - } - idAndEncLog := common.IDAndEncLog{ - SubID: subscription.ID, - EncLog: encryptedLog, - } - err = notifier.Notify(subscription.ID, idAndEncLog) - if err != nil { - api.logger.Error("could not send encrypted log to client on subscription ", log.SubIDKey, subscription.ID) - } - case <-time.After(10 * time.Second): - if unsubscribed.Load() { - return - } - } + go subscriptioncommon.ForwardFromChannels([]chan []byte{logsFromSubscription}, &unsubscribed, func(elem []byte) error { + msg := &common.IDAndEncLog{ + SubID: subscription.ID, + EncLog: elem, } - }() - - // unsubscribe commands are handled in a different go-routine to avoid deadlocking with the log processing - go func() { - <-subscription.Err() - api.host.Unsubscribe(subscription.ID) - unsubscribed.Store(true) - }() - + return notifier.Notify(subscription.ID, msg) + }) + go subscriptioncommon.HandleUnsubscribe(subscription, &unsubscribed, func() { + api.host.UnsubscribeLogs(subscription.ID) + }) return subscription, nil } diff --git a/go/rpc/client.go b/go/rpc/client.go index b1d05c6f6e..fbb2408702 100644 --- a/go/rpc/client.go +++ b/go/rpc/client.go @@ -26,15 +26,16 @@ const ( Health = "obscuro_health" Config = "obscuro_config" - GetBlockHeaderByHash = "tenscan_getBlockHeaderByHash" - GetBatch = "tenscan_getBatch" - GetBatchForTx = "tenscan_getBatchForTx" - GetLatestTxs = "tenscan_getLatestTransactions" - GetTotalTxs = "tenscan_getTotalTransactions" - Attestation = "tenscan_attestation" - StopHost = "test_stopHost" - SubscribeNamespace = "eth" - SubscriptionTypeLogs = "logs" + GetBlockHeaderByHash = "tenscan_getBlockHeaderByHash" + GetBatch = "tenscan_getBatch" + GetBatchForTx = "tenscan_getBatchForTx" + GetLatestTxs = "tenscan_getLatestTransactions" + GetTotalTxs = "tenscan_getTotalTransactions" + Attestation = "tenscan_attestation" + StopHost = "test_stopHost" + SubscribeNamespace = "eth" + SubscriptionTypeLogs = "logs" + SubscriptionTypeNewHeads = "newHeads" // GetL1RollupHeaderByHash = "scan_getL1RollupHeaderByHash" // GetActiveNodeCount = "scan_getActiveNodeCount" diff --git a/integration/obscurogateway/tengateway_test.go b/integration/obscurogateway/tengateway_test.go index c8846dedfc..56f35b8a53 100644 --- a/integration/obscurogateway/tengateway_test.go +++ b/integration/obscurogateway/tengateway_test.go @@ -103,6 +103,7 @@ func TestTenGateway(t *testing.T) { //"testAreTxsMinted": testAreTxsMinted, this breaks the other tests bc, enable once concurrency issues are fixed "testErrorHandling": testErrorHandling, "testMultipleAccountsSubscription": testMultipleAccountsSubscription, + "testNewHeadsSubscription": testNewHeadsSubscription, "testErrorsRevertedArePassed": testErrorsRevertedArePassed, "testUnsubscribe": testUnsubscribe, "testClosingConnectionWhileSubscribed": testClosingConnectionWhileSubscribed, @@ -123,6 +124,37 @@ func TestTenGateway(t *testing.T) { assert.NoError(t, err) } +func testNewHeadsSubscription(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { + user0, err := NewGatewayUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) + require.NoError(t, err) + + receivedHeads := make([]*types.Header, 0) + newHeadChan := make(chan *types.Header) + subscription, err := user0.WSClient.SubscribeNewHead(context.Background(), newHeadChan) + require.NoError(t, err) + + // Listen for new heads in a goroutine + go func() { + for { + select { + case err := <-subscription.Err(): + // if err != nil { + testlog.Logger().Info("Error from new head subscription", log2.ErrKey, err) + return + //} + case newHead := <-newHeadChan: + // append logs to be visible from the main thread + receivedHeads = append(receivedHeads, newHead) + } + } + }() + + // sleep for 5 seconds and there should be at least 2 heads received in this interval + time.Sleep(5 * time.Second) + subscription.Unsubscribe() + require.True(t, len(receivedHeads) > 1) +} + func testMultipleAccountsSubscription(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { user0, err := NewGatewayUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) require.NoError(t, err) diff --git a/tools/walletextension/rpcapi/filter_api.go b/tools/walletextension/rpcapi/filter_api.go index 14492fb4c6..7d7e4f85ce 100644 --- a/tools/walletextension/rpcapi/filter_api.go +++ b/tools/walletextension/rpcapi/filter_api.go @@ -3,11 +3,11 @@ package rpcapi import ( "context" "fmt" - "reflect" "sync/atomic" "time" - pool "github.com/jolestar/go-commons-pool/v2" + subscriptioncommon "github.com/ten-protocol/go-ten/go/common/subscription" + tenrpc "github.com/ten-protocol/go-ten/go/rpc" gethcommon "github.com/ethereum/go-ethereum/common" @@ -24,7 +24,9 @@ type FilterAPI struct { } func NewFilterAPI(we *Services) *FilterAPI { - return &FilterAPI{we: we} + return &FilterAPI{ + we: we, + } } func (api *FilterAPI) NewPendingTransactionFilter(_ *bool) rpc.ID { @@ -41,7 +43,13 @@ func (api *FilterAPI) NewBlockFilter() rpc.ID { } func (api *FilterAPI) NewHeads(ctx context.Context) (*rpc.Subscription, error) { - return nil, rpcNotImplemented + notifier, supported := rpc.NotifierFromContext(ctx) + if !supported { + return nil, fmt.Errorf("creation of subscriptions is not supported") + } + subscription := notifier.CreateSubscription() + api.we.NewHeadsService.RegisterNotifier(notifier, subscription) + return subscription, nil } func (api *FilterAPI) Logs(ctx context.Context, crit common.FilterCriteria) (*rpc.Subscription, error) { @@ -61,15 +69,15 @@ func (api *FilterAPI) Logs(ctx context.Context, crit common.FilterCriteria) (*rp } } + backendWSConnections := make([]*tenrpc.EncRPCClient, 0) inputChannels := make([]chan common.IDAndLog, 0) backendSubscriptions := make([]*rpc.ClientSubscription, 0) - connections := make([]*tenrpc.EncRPCClient, 0) for _, address := range candidateAddresses { rpcWSClient, err := connectWS(user.accounts[*address], api.we.Logger()) if err != nil { return nil, err } - connections = append(connections, rpcWSClient) + backendWSConnections = append(backendWSConnections, rpcWSClient) inCh := make(chan common.IDAndLog) backendSubscription, err := rpcWSClient.Subscribe(ctx, "eth", inCh, "logs", crit) @@ -86,7 +94,7 @@ func (api *FilterAPI) Logs(ctx context.Context, crit common.FilterCriteria) (*rp subscription := subNotifier.CreateSubscription() unsubscribed := atomic.Bool{} - go forwardAndDedupe(inputChannels, backendSubscriptions, subscription, subNotifier, &unsubscribed, func(data common.IDAndLog) *types.Log { + go subscriptioncommon.ForwardFromChannels(inputChannels, &unsubscribed, func(data common.IDAndLog) error { uniqueLogKey := LogKey{ BlockHash: data.Log.BlockHash, TxHash: data.Log.TxHash, @@ -95,12 +103,19 @@ func (api *FilterAPI) Logs(ctx context.Context, crit common.FilterCriteria) (*rp if !dedupeBuffer.Contains(uniqueLogKey) { dedupeBuffer.Push(uniqueLogKey) - return data.Log + return subNotifier.Notify(subscription.ID, data.Log) } return nil }) - go handleUnsubscribe(subscription, backendSubscriptions, connections, api.we.rpcWSConnPool, &unsubscribed) + go subscriptioncommon.HandleUnsubscribe(subscription, &unsubscribed, func() { + for _, backendSub := range backendSubscriptions { + backendSub.Unsubscribe() + } + for _, connection := range backendWSConnections { + _ = returnConn(api.we.rpcWSConnPool, connection.BackingClient()) + } + }) return subscription, err } @@ -137,63 +152,6 @@ func searchForAddressInFilterCriteria(filterCriteria common.FilterCriteria, poss return result } -// forwardAndDedupe - reads messages from the input channels, and forwards them to the notifier only if they are new -func forwardAndDedupe[R any, T any](inputChannels []chan R, _ []*rpc.ClientSubscription, outSub *rpc.Subscription, notifier *rpc.Notifier, unsubscribed *atomic.Bool, toForward func(elem R) *T) { - inputCases := make([]reflect.SelectCase, len(inputChannels)+1) - - // create a ticker to handle cleanup - inputCases[0] = reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(time.NewTicker(10 * time.Second).C), - } - - // create a select "case" for each input channel - for i, ch := range inputChannels { - inputCases[i+1] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)} - } - - unclosedInputChannels := len(inputCases) - for unclosedInputChannels > 0 { - chosen, value, ok := reflect.Select(inputCases) - if !ok { - // The chosen channel has been closed, so zero out the channel to disable the case - inputCases[chosen].Chan = reflect.ValueOf(nil) - unclosedInputChannels-- - continue - } - - switch v := value.Interface().(type) { - case time.Time: - // exit the loop to avoid a goroutine loop - if unsubscribed.Load() { - return - } - case R: - valueToSubmit := toForward(v) - if valueToSubmit != nil { - err := notifier.Notify(outSub.ID, *valueToSubmit) - if err != nil { - return - } - } - default: - // unexpected element received - continue - } - } -} - -func handleUnsubscribe(connectionSub *rpc.Subscription, backendSubscriptions []*rpc.ClientSubscription, connections []*tenrpc.EncRPCClient, p *pool.ObjectPool, unsubscribed *atomic.Bool) { - <-connectionSub.Err() - unsubscribed.Store(true) - for _, backendSub := range backendSubscriptions { - backendSub.Unsubscribe() - } - for _, connection := range connections { - _ = returnConn(p, connection.BackingClient()) - } -} - func (api *FilterAPI) NewFilter(crit common.FilterCriteria) (rpc.ID, error) { return rpc.NewID(), rpcNotImplemented } diff --git a/tools/walletextension/rpcapi/utils.go b/tools/walletextension/rpcapi/utils.go index e6e1317968..009c31f88a 100644 --- a/tools/walletextension/rpcapi/utils.go +++ b/tools/walletextension/rpcapi/utils.go @@ -20,7 +20,7 @@ import ( "github.com/ten-protocol/go-ten/tools/walletextension/cache" - "github.com/ethereum/go-ethereum/common" + gethcommon "github.com/ethereum/go-ethereum/common" ) const ( @@ -36,8 +36,8 @@ const ( var rpcNotImplemented = fmt.Errorf("rpc endpoint not implemented") type ExecCfg struct { - account *common.Address - computeFromCallback func(user *GWUser) *common.Address + account *gethcommon.Address + computeFromCallback func(user *GWUser) *gethcommon.Address tryAll bool tryUntilAuthorised bool adjustArgs func(acct *GWAccount) []any @@ -165,7 +165,7 @@ func extractUserID(ctx context.Context, _ *Services) ([]byte, error) { if !ok { return nil, fmt.Errorf("invalid userid: %s", ctx.Value(rpc.GWTokenKey{})) } - userID := common.FromHex(token) + userID := gethcommon.FromHex(token) if len(userID) != viewingkey.UserIDLength { return nil, fmt.Errorf("invalid userid: %s", token) } diff --git a/tools/walletextension/rpcapi/wallet_extension.go b/tools/walletextension/rpcapi/wallet_extension.go index 6e271a86dd..5d9e6abf4a 100644 --- a/tools/walletextension/rpcapi/wallet_extension.go +++ b/tools/walletextension/rpcapi/wallet_extension.go @@ -7,6 +7,11 @@ import ( "fmt" "time" + subscriptioncommon "github.com/ten-protocol/go-ten/go/common/subscription" + + common2 "github.com/ten-protocol/go-ten/go/common" + "github.com/ten-protocol/go-ten/go/rpc" + "github.com/ten-protocol/go-ten/go/obsclient" pool "github.com/jolestar/go-commons-pool/v2" @@ -37,9 +42,15 @@ type Services struct { version string Cache cache.Cache // the OG maintains a connection pool of rpc connections to underlying nodes - rpcHTTPConnPool *pool.ObjectPool - rpcWSConnPool *pool.ObjectPool - Config *common.Config + rpcHTTPConnPool *pool.ObjectPool + rpcWSConnPool *pool.ObjectPool + Config *common.Config + backendNewHeadsSubscription *gethrpc.ClientSubscription + NewHeadsService *subscriptioncommon.NewHeadsService +} + +type NewHeadNotifier interface { + onNewHead(header *common2.BatchHeader) } func NewServices(hostAddrHTTP string, hostAddrWS string, storage storage.Storage, stopControl *stopcontrol.StopControl, version string, logger gethlog.Logger, config *common.Config) *Services { @@ -79,7 +90,7 @@ func NewServices(hostAddrHTTP string, hostAddrWS string, storage storage.Storage cfg := pool.NewDefaultPoolConfig() cfg.MaxTotal = 100 // todo - what is the right number - return &Services{ + services := Services{ HostAddrHTTP: hostAddrHTTP, HostAddrWS: hostAddrWS, Storage: storage, @@ -92,6 +103,26 @@ func NewServices(hostAddrHTTP string, hostAddrWS string, storage storage.Storage rpcWSConnPool: pool.NewObjectPool(context.Background(), factoryWS, cfg), Config: config, } + + connectionObj, err := services.rpcWSConnPool.BorrowObject(context.Background()) + if err != nil { + panic(fmt.Errorf("cannot fetch rpc connection to backend node %w", err)) + } + + rpcClient := connectionObj.(rpc.Client) + ch := make(chan *common2.BatchHeader) + clientSubscription, err := rpcClient.Subscribe(context.Background(), rpc.SubscribeNamespace, ch, rpc.SubscriptionTypeNewHeads) + if err != nil { + panic(fmt.Errorf("cannot subscribe to new heads to the backend %w", err)) + } + + services.backendNewHeadsSubscription = clientSubscription + services.NewHeadsService = subscriptioncommon.NewNewHeadsService(ch, true, logger, func(newHead *common2.BatchHeader) error { + // todo - in a followup PR, invalidate cache entries marked as "latest" + return nil + }) + + return &services } // IsStopping returns whether the WE is stopping @@ -235,3 +266,9 @@ func (w *Services) GenerateUserMessageToSign(encryptionToken []byte, formatsSlic } return string(message), nil } + +func (w *Services) Stop() { + w.backendNewHeadsSubscription.Unsubscribe() + w.rpcHTTPConnPool.Close(context.Background()) + w.rpcWSConnPool.Close(context.Background()) +} diff --git a/tools/walletextension/walletextension_container.go b/tools/walletextension/walletextension_container.go index 505832a98e..084044af54 100644 --- a/tools/walletextension/walletextension_container.go +++ b/tools/walletextension/walletextension_container.go @@ -5,6 +5,8 @@ import ( "os" "time" + "github.com/ten-protocol/go-ten/go/common/subscription" + "github.com/ten-protocol/go-ten/tools/walletextension/api" "github.com/ten-protocol/go-ten/tools/walletextension/httpapi" @@ -22,9 +24,11 @@ import ( ) type Container struct { - stopControl *stopcontrol.StopControl - logger gethlog.Logger - rpcServer node.Server + stopControl *stopcontrol.StopControl + logger gethlog.Logger + rpcServer node.Server + services *rpcapi.Services + newHeadsService *subscription.NewHeadsService } func NewContainerFromConfig(config wecommon.Config, logger gethlog.Logger) *Container { @@ -92,28 +96,23 @@ func NewContainerFromConfig(config wecommon.Config, logger gethlog.Logger) *Cont }, }}) - return NewWalletExtensionContainer( - stopControl, - rpcServer, - logger, - ) -} - -func NewWalletExtensionContainer( - stopControl *stopcontrol.StopControl, - rpcServer node.Server, - logger gethlog.Logger, -) *Container { return &Container{ - stopControl: stopControl, - rpcServer: rpcServer, - logger: logger, + stopControl: stopControl, + rpcServer: rpcServer, + newHeadsService: walletExt.NewHeadsService, + services: walletExt, + logger: logger, } } // Start starts the wallet extension container func (w *Container) Start() error { - err := w.rpcServer.Start() + err := w.newHeadsService.Start() + if err != nil { + return err + } + + err = w.rpcServer.Start() if err != nil { return err } @@ -122,6 +121,7 @@ func (w *Container) Start() error { func (w *Container) Stop() error { w.stopControl.Stop() + _ = w.newHeadsService.Stop() if w.rpcServer != nil { // rpc server cannot be stopped synchronously as it will kill current request @@ -132,5 +132,6 @@ func (w *Container) Stop() error { }() } + w.services.Stop() return nil }