diff --git a/metering/metering.go b/metering/metering.go index 57c6d5b8..7fb11048 100644 --- a/metering/metering.go +++ b/metering/metering.go @@ -113,7 +113,7 @@ func NewMetricsSender() *MetricsSender { } } -func (ms *MetricsSender) Send(ctx context.Context, userID, apiKeyID, ip, userMeta, endpoint string, resp proto.Message) { +func (ms *MetricsSender) Send(ctx context.Context, userID, apiKeyID, ip, userMeta, outputModuleHash, endpoint string, resp proto.Message) { ms.Lock() defer ms.Unlock() @@ -121,8 +121,6 @@ func (ms *MetricsSender) Send(ctx context.Context, userID, apiKeyID, ip, userMet endpoint = fmt.Sprintf("%s%s", endpoint, "Backfill") } - outputModuleHash := reqctx.OutputModuleHash(ctx) - meter := dmetering.GetBytesMeter(ctx) bytesRead := meter.BytesReadDelta() diff --git a/metering/metering_test.go b/metering/metering_test.go index 7b893ecf..afefdb78 100644 --- a/metering/metering_test.go +++ b/metering/metering_test.go @@ -386,8 +386,10 @@ func TestSend(t *testing.T) { metericsSender := NewMetricsSender() + outputModuleHash := "outputModuleHash" + // Call the Send function - metericsSender.Send(ctx, "user1", "apiKey1", "127.0.0.1", "meta", "endpoint", resp) + metericsSender.Send(ctx, "user1", "apiKey1", "127.0.0.1", "meta", outputModuleHash, "endpoint", resp) // Verify the emitted event assert.Len(t, emitter.events, 1) @@ -398,6 +400,7 @@ func TestSend(t *testing.T) { assert.Equal(t, "127.0.0.1", event.IpAddress) assert.Equal(t, "meta", event.Meta) assert.Equal(t, "endpoint", event.Endpoint) + assert.Equal(t, "outputModuleHash", event.OutputModuleHash) assert.Equal(t, float64(proto.Size(resp)), event.Metrics["egress_bytes"]) assert.Equal(t, float64(0), event.Metrics["written_bytes"]) assert.Equal(t, float64(0), event.Metrics["read_bytes"]) @@ -448,7 +451,7 @@ func TestSendParallel(t *testing.T) { meter.CountInc(MeterFileCompressedWriteBytes, 600) time.Sleep(time.Duration(randomInt()) * time.Nanosecond) - metricsSender.Send(ctx, "user1", "apiKey1", "127.0.0.1", "meta", "endpoint", resp) + metricsSender.Send(ctx, "user1", "apiKey1", "127.0.0.1", "meta", "outputModuleHash", "endpoint", resp) }() } diff --git a/service/tier1.go b/service/tier1.go index e8ddc118..545f7e6b 100644 --- a/service/tier1.go +++ b/service/tier1.go @@ -209,16 +209,6 @@ func (s *Tier1Service) Blocks( ctx, span := reqctx.WithSpan(ctx, "substreams/tier1/request") defer span.EndWithErr(&err) - // We need to ensure that the response function is NEVER used after this Blocks handler has returned. - // We use a context that will be canceled on defer, and a lock to prevent races. The respFunc is used in various threads - mut := sync.Mutex{} - respContext, cancel := context.WithCancel(ctx) - defer func() { - mut.Lock() - cancel() - mut.Unlock() - }() - request := req.Msg if request.Modules == nil { return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("missing modules in request")) @@ -231,6 +221,16 @@ func (s *Tier1Service) Blocks( outputModuleHash := execGraph.ModuleHashes().Get(request.OutputModule) ctx = reqctx.WithOutputModuleHash(ctx, outputModuleHash) + // We need to ensure that the response function is NEVER used after this Blocks handler has returned. + // We use a context that will be canceled on defer, and a lock to prevent races. The respFunc is used in various threads + mut := sync.Mutex{} + respContext, cancel := context.WithCancel(ctx) + defer func() { + mut.Lock() + cancel() + mut.Unlock() + }() + respFunc := tier1ResponseHandler(respContext, &mut, logger, stream) span.SetAttributes(attribute.Int64("substreams.tier", 1)) @@ -608,6 +608,8 @@ func tier1ResponseHandler(ctx context.Context, mut *sync.Mutex, logger *zap.Logg userMeta := auth.Meta() ip := auth.RealIP() + outputModuleHash := reqctx.OutputModuleHash(ctx) + ctx = reqctx.WithEmitter(ctx, dmetering.GetDefaultEmitter()) metericsSender := metering.GetMetricsSender(ctx) @@ -625,7 +627,7 @@ func tier1ResponseHandler(ctx context.Context, mut *sync.Mutex, logger *zap.Logg return connect.NewError(connect.CodeUnavailable, err) } - metericsSender.Send(ctx, userID, apiKeyID, ip, userMeta, "sf.substreams.rpc.v2/Blocks", resp) + metericsSender.Send(ctx, userID, apiKeyID, ip, userMeta, outputModuleHash, "sf.substreams.rpc.v2/Blocks", resp) return nil } } diff --git a/service/tier2.go b/service/tier2.go index 4cf05aca..e649501e 100644 --- a/service/tier2.go +++ b/service/tier2.go @@ -514,6 +514,7 @@ func tier2ResponseHandler(ctx context.Context, logger *zap.Logger, streamSrv pbs logger.Warn("no auth information available in tier2 response handler") } + outputModuleHash := reqctx.OutputModuleHash(ctx) metricsSender := metering.GetMetricsSender(ctx) return func(respAny substreams.ResponseFromAnyTier) error { @@ -530,7 +531,7 @@ func tier2ResponseHandler(ctx context.Context, logger *zap.Logger, streamSrv pbs zap.String("user_meta", userMeta), zap.String("endpoint", "sf.substreams.internal.v2/ProcessRange"), ) - metricsSender.Send(ctx, userID, apiKeyID, ip, userMeta, "sf.substreams.internal.v2/ProcessRange", resp) + metricsSender.Send(ctx, userID, apiKeyID, ip, userMeta, outputModuleHash, "sf.substreams.internal.v2/ProcessRange", resp) return nil } } diff --git a/test/collector_test.go b/test/collector_test.go index d47b6579..3921815e 100644 --- a/test/collector_test.go +++ b/test/collector_test.go @@ -64,10 +64,10 @@ func (c *responseCollector) Collect(respAny substreams.ResponseFromAnyTier) erro switch resp := respAny.(type) { case *pbsubstreamsrpc.Response: c.responses = append(c.responses, resp) - c.sender.Send(c.ctx, "test_user", "test_api_key", "10.0.0.1", "test_meta", "tier1", resp) + c.sender.Send(c.ctx, "test_user", "test_api_key", "10.0.0.1", "test_meta", "testOutputHash", "tier1", resp) case *pbssinternal.ProcessRangeResponse: c.internalResponses = append(c.internalResponses, resp) - c.sender.Send(c.ctx, "test_user", "test_api_key", "10.0.0.1", "test_meta", "tier2", resp) + c.sender.Send(c.ctx, "test_user", "test_api_key", "10.0.0.1", "test_meta", "testOutputHash", "tier2", resp) } return nil }