From fab35cf300278b4f8f22f6fa9730a2c2aed0abfa Mon Sep 17 00:00:00 2001 From: Jacob Aronoff Date: Fri, 26 Jan 2024 14:48:03 -0500 Subject: [PATCH] Add client context propagation (#248) This is a follow up to #237 and #247, adding context propagation for client methods. **This involves a breaking change for the client interfaces** --- client/clientimpl_test.go | 12 +++--- client/internal/httpsender.go | 4 +- client/internal/httpsender_test.go | 8 ++-- client/internal/receivedprocessor.go | 13 +++--- client/internal/wsreceiver_test.go | 4 +- client/types/callbacks.go | 42 +++++++++---------- client/wsclient.go | 4 +- client/wsclient_test.go | 4 +- internal/examples/agent/agent/agent.go | 12 +++--- .../supervisor/supervisor/supervisor.go | 12 +++--- 10 files changed, 58 insertions(+), 57 deletions(-) diff --git a/client/clientimpl_test.go b/client/clientimpl_test.go index 3ab15692..56f27310 100644 --- a/client/clientimpl_test.go +++ b/client/clientimpl_test.go @@ -161,7 +161,7 @@ func TestOnConnectFail(t *testing.T) { var connectErr atomic.Value settings := createNoServerSettings() settings.Callbacks = types.CallbacksStruct{ - OnConnectFailedFunc: func(err error) { + OnConnectFailedFunc: func(ctx context.Context, err error) { connectErr.Store(err) }, } @@ -238,7 +238,7 @@ func TestConnectWithServer(t *testing.T) { var connected int64 settings := types.StartSettings{ Callbacks: types.CallbacksStruct{ - OnConnectFunc: func() { + OnConnectFunc: func(ctx context.Context) { atomic.StoreInt64(&connected, 1) }, }, @@ -276,11 +276,11 @@ func TestConnectWithServer503(t *testing.T) { var connectErr atomic.Value settings := types.StartSettings{ Callbacks: types.CallbacksStruct{ - OnConnectFunc: func() { + OnConnectFunc: func(ctx context.Context) { atomic.StoreInt64(&clientConnected, 1) assert.Fail(t, "Client should not be able to connect") }, - OnConnectFailedFunc: func(err error) { + OnConnectFailedFunc: func(ctx context.Context, err error) { connectErr.Store(err) }, }, @@ -405,7 +405,7 @@ func TestFirstStatusReport(t *testing.T) { var connected, remoteConfigReceived int64 settings := types.StartSettings{ Callbacks: types.CallbacksStruct{ - OnConnectFunc: func() { + OnConnectFunc: func(ctx context.Context) { atomic.AddInt64(&connected, 1) }, OnMessageFunc: func(ctx context.Context, msg *types.MessageData) { @@ -458,7 +458,7 @@ func TestIncludesDetailsOnReconnect(t *testing.T) { var connected int64 settings := types.StartSettings{ Callbacks: types.CallbacksStruct{ - OnConnectFunc: func() { + OnConnectFunc: func(ctx context.Context) { atomic.AddInt64(&connected, 1) }, }, diff --git a/client/internal/httpsender.go b/client/internal/httpsender.go index f5a50279..00516423 100644 --- a/client/internal/httpsender.go +++ b/client/internal/httpsender.go @@ -179,7 +179,7 @@ func (h *HTTPSender) sendRequestWithRetries(ctx context.Context) (*http.Response switch resp.StatusCode { case http.StatusOK: // We consider it connected if we receive 200 status from the Server. - h.callbacks.OnConnect() + h.callbacks.OnConnect(ctx) return resp, nil case http.StatusTooManyRequests, http.StatusServiceUnavailable: @@ -195,7 +195,7 @@ func (h *HTTPSender) sendRequestWithRetries(ctx context.Context) (*http.Response } h.logger.Errorf(ctx, "Failed to do HTTP request (%v), will retry", err) - h.callbacks.OnConnectFailed(err) + h.callbacks.OnConnectFailed(ctx, err) } case <-ctx.Done(): diff --git a/client/internal/httpsender_test.go b/client/internal/httpsender_test.go index c8e2855e..420bc9e1 100644 --- a/client/internal/httpsender_test.go +++ b/client/internal/httpsender_test.go @@ -47,9 +47,9 @@ func TestHTTPSenderRetryForStatusTooManyRequests(t *testing.T) { } }) sender.callbacks = types.CallbacksStruct{ - OnConnectFunc: func() { + OnConnectFunc: func(ctx context.Context) { }, - OnConnectFailedFunc: func(_ error) { + OnConnectFailedFunc: func(ctx context.Context, _ error) { }, } sender.url = url @@ -144,9 +144,9 @@ func TestHTTPSenderRetryForFailedRequests(t *testing.T) { } }) sender.callbacks = types.CallbacksStruct{ - OnConnectFunc: func() { + OnConnectFunc: func(ctx context.Context) { }, - OnConnectFailedFunc: func(_ error) { + OnConnectFailedFunc: func(ctx context.Context, _ error) { }, } sender.url = url diff --git a/client/internal/receivedprocessor.go b/client/internal/receivedprocessor.go index 58419041..fc0ab3d6 100644 --- a/client/internal/receivedprocessor.go +++ b/client/internal/receivedprocessor.go @@ -56,7 +56,7 @@ func (r *receivedProcessor) ProcessReceivedMessage(ctx context.Context, msg *pro // to process. if msg.Command != nil { if r.hasCapability(protobufs.AgentCapabilities_AgentCapabilities_AcceptsRestartCommand) { - r.rcvCommand(msg.Command) + r.rcvCommand(ctx, msg.Command) // If a command message exists, other messages will be ignored return } else { @@ -198,7 +198,7 @@ func (r *receivedProcessor) rcvOpampConnectionSettings(ctx context.Context, sett err := r.callbacks.OnOpampConnectionSettings(ctx, settings.Opamp) if err == nil { // TODO: verify connection using new settings. - r.callbacks.OnOpampConnectionSettingsAccepted(settings.Opamp) + r.callbacks.OnOpampConnectionSettingsAccepted(ctx, settings.Opamp) } } else { r.logger.Debugf(ctx, "Ignoring Opamp, agent does not have AcceptsOpAMPConnectionSettings capability") @@ -206,8 +206,9 @@ func (r *receivedProcessor) rcvOpampConnectionSettings(ctx context.Context, sett } func (r *receivedProcessor) processErrorResponse(ctx context.Context, body *protobufs.ServerErrorResponse) { - // TODO: implement this. - r.logger.Errorf(ctx, "received an error from server: %s", body.ErrorMessage) + if body != nil { + r.callbacks.OnError(ctx, body) + } } func (r *receivedProcessor) rcvAgentIdentification(ctx context.Context, agentId *protobufs.AgentIdentification) error { @@ -226,8 +227,8 @@ func (r *receivedProcessor) rcvAgentIdentification(ctx context.Context, agentId return nil } -func (r *receivedProcessor) rcvCommand(command *protobufs.ServerToAgentCommand) { +func (r *receivedProcessor) rcvCommand(ctx context.Context, command *protobufs.ServerToAgentCommand) { if command != nil { - r.callbacks.OnCommand(command) + r.callbacks.OnCommand(ctx, command) } } diff --git a/client/internal/wsreceiver_test.go b/client/internal/wsreceiver_test.go index b428879c..c929f8fa 100644 --- a/client/internal/wsreceiver_test.go +++ b/client/internal/wsreceiver_test.go @@ -72,7 +72,7 @@ func TestServerToAgentCommand(t *testing.T) { action := none callbacks := types.CallbacksStruct{ - OnCommandFunc: func(command *protobufs.ServerToAgentCommand) error { + OnCommandFunc: func(ctx context.Context, command *protobufs.ServerToAgentCommand) error { switch command.Type { case protobufs.CommandType_CommandType_Restart: action = restart @@ -132,7 +132,7 @@ func TestServerToAgentCommandExclusive(t *testing.T) { calledOnMessageConfig := false callbacks := types.CallbacksStruct{ - OnCommandFunc: func(command *protobufs.ServerToAgentCommand) error { + OnCommandFunc: func(ctx context.Context, command *protobufs.ServerToAgentCommand) error { calledCommand = true return nil }, diff --git a/client/types/callbacks.go b/client/types/callbacks.go index 20f8fe32..9952ddd7 100644 --- a/client/types/callbacks.go +++ b/client/types/callbacks.go @@ -40,23 +40,24 @@ type MessageData struct { } // Callbacks is an interface for the Client to handle messages from the Server. +// Callbacks are expected to honour the context passed to them, meaning they should be aware of cancellations. type Callbacks interface { // OnConnect is called when the connection is successfully established to the Server. // May be called after Start() is called and every time a connection is established to the Server. // For WebSocket clients this is called after the handshake is completed without any error. // For HTTP clients this is called for any request if the response status is OK. - OnConnect() + OnConnect(ctx context.Context) // OnConnectFailed is called when the connection to the Server cannot be established. // May be called after Start() is called and tries to connect to the Server. // May also be called if the connection is lost and reconnection attempt fails. - OnConnectFailed(err error) + OnConnectFailed(ctx context.Context, err error) // OnError is called when the Server reports an error in response to some previously // sent request. Useful for logging purposes. The Agent should not attempt to process // the error by reconnecting or retrying previous operations. The client handles the // ErrorResponse_UNAVAILABLE case internally by performing retries as necessary. - OnError(err *protobufs.ServerErrorResponse) + OnError(ctx context.Context, err *protobufs.ServerErrorResponse) // OnMessage is called when the Agent receives a message that needs processing. // See MessageData definition for the data that may be available for processing. @@ -94,9 +95,7 @@ type Callbacks interface { // verified and accepted (OnOpampConnectionSettingsOffer and connection using // new settings succeeds). The Agent should store the settings and use them // in the future. Old connection settings should be forgotten. - OnOpampConnectionSettingsAccepted( - settings *protobufs.OpAMPConnectionSettings, - ) + OnOpampConnectionSettingsAccepted(ctx context.Context, settings *protobufs.OpAMPConnectionSettings) // For all methods that accept a context parameter the caller may cancel the // context if processing takes too long. In that case the method should return @@ -115,15 +114,15 @@ type Callbacks interface { GetEffectiveConfig(ctx context.Context) (*protobufs.EffectiveConfig, error) // OnCommand is called when the Server requests that the connected Agent perform a command. - OnCommand(command *protobufs.ServerToAgentCommand) error + OnCommand(ctx context.Context, command *protobufs.ServerToAgentCommand) error } // CallbacksStruct is a struct that implements Callbacks interface and allows // to override only the methods that are needed. If a method is not overridden then it is a no-op. type CallbacksStruct struct { - OnConnectFunc func() - OnConnectFailedFunc func(err error) - OnErrorFunc func(err *protobufs.ServerErrorResponse) + OnConnectFunc func(ctx context.Context) + OnConnectFailedFunc func(ctx context.Context, err error) + OnErrorFunc func(ctx context.Context, err *protobufs.ServerErrorResponse) OnMessageFunc func(ctx context.Context, msg *MessageData) @@ -132,10 +131,11 @@ type CallbacksStruct struct { settings *protobufs.OpAMPConnectionSettings, ) error OnOpampConnectionSettingsAcceptedFunc func( + ctx context.Context, settings *protobufs.OpAMPConnectionSettings, ) - OnCommandFunc func(command *protobufs.ServerToAgentCommand) error + OnCommandFunc func(ctx context.Context, command *protobufs.ServerToAgentCommand) error SaveRemoteConfigStatusFunc func(ctx context.Context, status *protobufs.RemoteConfigStatus) GetEffectiveConfigFunc func(ctx context.Context) (*protobufs.EffectiveConfig, error) @@ -144,23 +144,23 @@ type CallbacksStruct struct { var _ Callbacks = (*CallbacksStruct)(nil) // OnConnect implements Callbacks.OnConnect. -func (c CallbacksStruct) OnConnect() { +func (c CallbacksStruct) OnConnect(ctx context.Context) { if c.OnConnectFunc != nil { - c.OnConnectFunc() + c.OnConnectFunc(ctx) } } // OnConnectFailed implements Callbacks.OnConnectFailed. -func (c CallbacksStruct) OnConnectFailed(err error) { +func (c CallbacksStruct) OnConnectFailed(ctx context.Context, err error) { if c.OnConnectFailedFunc != nil { - c.OnConnectFailedFunc(err) + c.OnConnectFailedFunc(ctx, err) } } // OnError implements Callbacks.OnError. -func (c CallbacksStruct) OnError(err *protobufs.ServerErrorResponse) { +func (c CallbacksStruct) OnError(ctx context.Context, err *protobufs.ServerErrorResponse) { if c.OnErrorFunc != nil { - c.OnErrorFunc(err) + c.OnErrorFunc(ctx, err) } } @@ -197,16 +197,16 @@ func (c CallbacksStruct) OnOpampConnectionSettings( } // OnOpampConnectionSettingsAccepted implements Callbacks.OnOpampConnectionSettingsAccepted. -func (c CallbacksStruct) OnOpampConnectionSettingsAccepted(settings *protobufs.OpAMPConnectionSettings) { +func (c CallbacksStruct) OnOpampConnectionSettingsAccepted(ctx context.Context, settings *protobufs.OpAMPConnectionSettings) { if c.OnOpampConnectionSettingsAcceptedFunc != nil { - c.OnOpampConnectionSettingsAcceptedFunc(settings) + c.OnOpampConnectionSettingsAcceptedFunc(ctx, settings) } } // OnCommand implements Callbacks.OnCommand. -func (c CallbacksStruct) OnCommand(command *protobufs.ServerToAgentCommand) error { +func (c CallbacksStruct) OnCommand(ctx context.Context, command *protobufs.ServerToAgentCommand) error { if c.OnCommandFunc != nil { - return c.OnCommandFunc(command) + return c.OnCommandFunc(ctx, command) } return nil } diff --git a/client/wsclient.go b/client/wsclient.go index 15b790aa..268ffddd 100644 --- a/client/wsclient.go +++ b/client/wsclient.go @@ -128,7 +128,7 @@ func (c *wsClient) tryConnectOnce(ctx context.Context) (err error, retryAfter sh conn, resp, err := c.dialer.DialContext(ctx, c.url.String(), c.requestHeader) if err != nil { if c.common.Callbacks != nil && !c.common.IsStopping() { - c.common.Callbacks.OnConnectFailed(err) + c.common.Callbacks.OnConnectFailed(ctx, err) } if resp != nil { c.common.Logger.Errorf(ctx, "Server responded with status=%v", resp.Status) @@ -143,7 +143,7 @@ func (c *wsClient) tryConnectOnce(ctx context.Context) (err error, retryAfter sh c.conn = conn c.connMutex.Unlock() if c.common.Callbacks != nil { - c.common.Callbacks.OnConnect() + c.common.Callbacks.OnConnect(ctx) } return nil, sharedinternal.OptionalDuration{Defined: false} diff --git a/client/wsclient_test.go b/client/wsclient_test.go index 715b140f..7696e9e6 100644 --- a/client/wsclient_test.go +++ b/client/wsclient_test.go @@ -31,10 +31,10 @@ func TestDisconnectWSByServer(t *testing.T) { var connectErr atomic.Value settings := types.StartSettings{ Callbacks: types.CallbacksStruct{ - OnConnectFunc: func() { + OnConnectFunc: func(ctx context.Context) { atomic.StoreInt64(&connected, 1) }, - OnConnectFailedFunc: func(err error) { + OnConnectFailedFunc: func(ctx context.Context, err error) { connectErr.Store(err) }, }, diff --git a/internal/examples/agent/agent/agent.go b/internal/examples/agent/agent/agent.go index 8dfbed68..bc89c3b8 100644 --- a/internal/examples/agent/agent/agent.go +++ b/internal/examples/agent/agent/agent.go @@ -109,14 +109,14 @@ func (agent *Agent) connect() error { TLSConfig: tlsConfig, InstanceUid: agent.instanceId.String(), Callbacks: types.CallbacksStruct{ - OnConnectFunc: func() { - agent.logger.Debugf(context.Background(), "Connected to the server.") + OnConnectFunc: func(ctx context.Context) { + agent.logger.Debugf(ctx, "Connected to the server.") }, - OnConnectFailedFunc: func(err error) { - agent.logger.Errorf(context.Background(), "Failed to connect to the server: %v", err) + OnConnectFailedFunc: func(ctx context.Context, err error) { + agent.logger.Errorf(ctx, "Failed to connect to the server: %v", err) }, - OnErrorFunc: func(err *protobufs.ServerErrorResponse) { - agent.logger.Errorf(context.Background(), "Server returned an error response: %v", err.ErrorMessage) + OnErrorFunc: func(ctx context.Context, err *protobufs.ServerErrorResponse) { + agent.logger.Errorf(ctx, "Server returned an error response: %v", err.ErrorMessage) }, SaveRemoteConfigStatusFunc: func(_ context.Context, status *protobufs.RemoteConfigStatus) { agent.remoteConfigStatus = status diff --git a/internal/examples/supervisor/supervisor/supervisor.go b/internal/examples/supervisor/supervisor/supervisor.go index 9d2c577b..c47d9122 100644 --- a/internal/examples/supervisor/supervisor/supervisor.go +++ b/internal/examples/supervisor/supervisor/supervisor.go @@ -140,14 +140,14 @@ func (s *Supervisor) startOpAMP() error { }, InstanceUid: s.instanceId.String(), Callbacks: types.CallbacksStruct{ - OnConnectFunc: func() { - s.logger.Debugf(context.Background(), "Connected to the server.") + OnConnectFunc: func(ctx context.Context) { + s.logger.Debugf(ctx, "Connected to the server.") }, - OnConnectFailedFunc: func(err error) { - s.logger.Errorf(context.Background(), "Failed to connect to the server: %v", err) + OnConnectFailedFunc: func(ctx context.Context, err error) { + s.logger.Errorf(ctx, "Failed to connect to the server: %v", err) }, - OnErrorFunc: func(err *protobufs.ServerErrorResponse) { - s.logger.Errorf(context.Background(), "Server returned an error response: %v", err.ErrorMessage) + OnErrorFunc: func(ctx context.Context, err *protobufs.ServerErrorResponse) { + s.logger.Errorf(ctx, "Server returned an error response: %v", err.ErrorMessage) }, GetEffectiveConfigFunc: func(ctx context.Context) (*protobufs.EffectiveConfig, error) { return s.createEffectiveConfigMsg(), nil