diff --git a/go/common/host/services.go b/go/common/host/services.go index 3c1d532704..2ee5bbee1b 100644 --- a/go/common/host/services.go +++ b/go/common/host/services.go @@ -123,7 +123,7 @@ type L1Publisher interface { // L2BatchRepository provides an interface for the host to request L2 batch data (live-streaming and historical) type L2BatchRepository interface { // Subscribe will register a batch handler to receive new batches as they arrive - Subscribe(handler L2BatchHandler) + Subscribe(handler L2BatchHandler) func() FetchBatchBySeqNo(seqNo *big.Int) (*common.ExtBatch, error) diff --git a/go/host/enclave/guardian.go b/go/host/enclave/guardian.go index 40eef51108..8584a1dc43 100644 --- a/go/host/enclave/guardian.go +++ b/go/host/enclave/guardian.go @@ -124,6 +124,8 @@ func (g *Guardian) Start() error { // subscribe for L1 and P2P data g.sl.P2P().SubscribeForTx(g) + + // note: not keeping the unsubscribe functions because the lifespan of the guardian is the same as the host g.sl.L1Repo().Subscribe(g) g.sl.L2Repo().Subscribe(g) diff --git a/go/host/l2/batchrepository.go b/go/host/l2/batchrepository.go index e6f6163059..d833ad415f 100644 --- a/go/host/l2/batchrepository.go +++ b/go/host/l2/batchrepository.go @@ -13,6 +13,7 @@ import ( "github.com/ten-protocol/go-ten/go/common/errutil" "github.com/ten-protocol/go-ten/go/common/host" "github.com/ten-protocol/go-ten/go/common/log" + "github.com/ten-protocol/go-ten/go/common/subscription" "github.com/ten-protocol/go-ten/go/config" "github.com/ten-protocol/go-ten/go/host/storage" ) @@ -34,7 +35,7 @@ type batchRepoServiceLocator interface { // Repository is responsible for storing and retrieving batches from the database // If it can't find a batch it will request it from peers. It also subscribes for batch requests from peers and responds to them. type Repository struct { - subscribers []host.L2BatchHandler + batchSubscribers *subscription.Manager[host.L2BatchHandler] sl batchRepoServiceLocator storage storage.Storage @@ -57,6 +58,7 @@ type Repository struct { func NewBatchRepository(cfg *config.HostConfig, hostService batchRepoServiceLocator, storage storage.Storage, logger gethlog.Logger) *Repository { return &Repository{ + batchSubscribers: subscription.NewManager[host.L2BatchHandler](), sl: hostService, storage: storage, isSequencer: cfg.NodeType == common.Sequencer, @@ -147,9 +149,9 @@ func (r *Repository) HandleBatchRequest(requesterID string, fromSeqNo *big.Int) } } -// Subscribe registers a handler to be notified of new head batches as they arrive -func (r *Repository) Subscribe(subscriber host.L2BatchHandler) { - r.subscribers = append(r.subscribers, subscriber) +// Subscribe registers a handler to be notified of new head batches as they arrive, returns unsubscribe func +func (r *Repository) Subscribe(handler host.L2BatchHandler) func() { + return r.batchSubscribers.Subscribe(handler) } func (r *Repository) FetchBatchBySeqNo(seqNo *big.Int) (*common.ExtBatch, error) { @@ -185,6 +187,10 @@ func (r *Repository) AddBatch(batch *common.ExtBatch) error { defer r.latestSeqNoMutex.Unlock() if batch.Header.SequencerOrderNo.Cmp(r.latestBatchSeqNo) > 0 { r.latestBatchSeqNo = batch.Header.SequencerOrderNo + // notify subscribers, a new batch has been successfully added to the db + for _, subscriber := range r.batchSubscribers.Subscribers() { + go subscriber.HandleBatch(batch) + } } return nil } diff --git a/tools/walletextension/httpapi/utils.go b/tools/walletextension/httpapi/utils.go index c7acc16b8c..e367856abc 100644 --- a/tools/walletextension/httpapi/utils.go +++ b/tools/walletextension/httpapi/utils.go @@ -25,10 +25,13 @@ func getUserID(conn UserConn) ([]byte, error) { // try getting userID (`token`) from query parameters and return it if successful userID, err := getQueryParameter(conn.ReadRequestParams(), common.EncryptedTokenQueryParameter) if err == nil { - if len(userID) != common.MessageUserIDLenWithPrefix { - return nil, fmt.Errorf(fmt.Sprintf("wrong length of userID from URL. Got: %d, Expected: %d", len(userID), common.MessageUserIDLenWithPrefix)) + if len(userID) == common.MessageUserIDLenWithPrefix { + return hexutils.HexToBytes(userID[2:]), nil + } else if len(userID) == common.MessageUserIDLen { + return hexutils.HexToBytes(userID), nil } - return hexutils.HexToBytes(userID[2:]), err + + return nil, fmt.Errorf(fmt.Sprintf("wrong length of userID from URL. Got: %d, Expected: %d od %d", len(userID), common.MessageUserIDLenWithPrefix, common.MessageUserIDLen)) } return nil, fmt.Errorf("missing token field")