From b1e0d82f4348c0da0678a91d2e0d818a290b8d22 Mon Sep 17 00:00:00 2001 From: Nikolay Edigaryev Date: Sat, 21 Nov 2020 23:25:58 +0300 Subject: [PATCH] cirrus worker run: ability to override RPC endpoint (#173) * cirrus worker run: ability to override RPC endpoint * Make RPC endpoint flag similar to agent's -api-endpoint --- go.mod | 2 +- go.sum | 4 ++-- internal/commands/worker/run.go | 18 ++++++++++++++++-- internal/worker/options.go | 6 ------ internal/worker/task.go | 9 ++------- internal/worker/worker.go | 11 ++++++++--- internal/worker/worker_test.go | 3 +-- 7 files changed, 30 insertions(+), 23 deletions(-) diff --git a/go.mod b/go.mod index 208a0fc7..3022a33c 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/avast/retry-go v3.0.0+incompatible github.com/bmatcuk/doublestar v1.3.2 github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054 - github.com/cirruslabs/cirrus-ci-agent v1.20.0 + github.com/cirruslabs/cirrus-ci-agent v1.21.0 github.com/cirruslabs/echelon v1.4.0 github.com/cirruslabs/podmanapi v0.1.0 github.com/containerd/containerd v1.4.1 // indirect diff --git a/go.sum b/go.sum index e70c4558..51b663e1 100644 --- a/go.sum +++ b/go.sum @@ -76,8 +76,8 @@ github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghf github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/cirruslabs/cirrus-ci-agent v1.20.0 h1:b38J1H1MDCxZIg32cY9+9VFLvUyx+tY0hkXfmUAfNs8= -github.com/cirruslabs/cirrus-ci-agent v1.20.0/go.mod h1:ga2zCGBfC+/+tnlHWa5cHwEuBPnhFuaf1PgTpWPGJWo= +github.com/cirruslabs/cirrus-ci-agent v1.21.0 h1:cbNgskPq+frZFwEg5eD71txM/fYBJmQSaGXCGRJy8Ss= +github.com/cirruslabs/cirrus-ci-agent v1.21.0/go.mod h1:ga2zCGBfC+/+tnlHWa5cHwEuBPnhFuaf1PgTpWPGJWo= github.com/cirruslabs/cirrus-ci-annotations v0.0.0-20200908203753-b813f63941d7/go.mod h1:98qD7HLlBx5aNqWiCH80OTTqTTsbXT69wxnlnrnoL0E= github.com/cirruslabs/echelon v1.4.0 h1:xubCf8BLFEBl1kamBZ1zjBrcw5p4z4anvJUBeR3E5YY= github.com/cirruslabs/echelon v1.4.0/go.mod h1:1jFBACMy3tzodXyTtNNLN9bw6UUU7Xpq9tYMRydehtY= diff --git a/internal/commands/worker/run.go b/internal/commands/worker/run.go index 3ebaac1a..64977b07 100644 --- a/internal/commands/worker/run.go +++ b/internal/commands/worker/run.go @@ -15,6 +15,9 @@ var ( name string token string labels map[string]string + + // RPC-related variables. + rpcEndpointAddress string ) func run(cmd *cobra.Command, args []string) error { @@ -26,11 +29,17 @@ func run(cmd *cobra.Command, args []string) error { } } - worker, err := worker.New( + opts := []worker.Option{ worker.WithName(viper.GetString("name")), worker.WithRegistrationToken(viper.GetString("token")), worker.WithLabels(viper.GetStringMapString("labels")), - ) + } + + if rpcEndpointAddress != "" { + opts = append(opts, worker.WithRPCEndpoint(rpcEndpointAddress)) + } + + worker, err := worker.New(opts...) if err != nil { return err } @@ -61,5 +70,10 @@ func NewRunCmd() *cobra.Command { "additional labels to use (e.g. --labels distro=debian)") _ = viper.BindPFlag("labels", cmd.PersistentFlags().Lookup("labels")) + // RPC-related variables + cmd.PersistentFlags().StringVar(&rpcEndpointAddress, "rpc-endpoint", worker.DefaultRPCEndpoint, "RPC endpoint address") + _ = viper.BindPFlag("rpc.endpoint", cmd.PersistentFlags().Lookup("rpc-endpoint")) + _ = cmd.PersistentFlags().MarkHidden("rpc-endpoint") + return cmd } diff --git a/internal/worker/options.go b/internal/worker/options.go index 3c580a0b..3981403a 100644 --- a/internal/worker/options.go +++ b/internal/worker/options.go @@ -25,9 +25,3 @@ func WithRPCEndpoint(rpcEndpoint string) Option { e.rpcEndpoint = rpcEndpoint } } - -func WithRPCInsecure() Option { - return func(e *Worker) { - e.rpcInsecure = true - } -} diff --git a/internal/worker/task.go b/internal/worker/task.go index 28897df2..c719b133 100644 --- a/internal/worker/task.go +++ b/internal/worker/task.go @@ -35,15 +35,10 @@ func (worker *Worker) runTask(ctx context.Context, agentAwareTask *api.PollRespo return } - rpcPrefix := "https://" - if worker.rpcInsecure { - rpcPrefix = "http://" - } - if err := inst.Run(taskCtx, &instance.RunConfig{ ProjectDir: "", - ContainerEndpoint: rpcPrefix + worker.rpcEndpoint, - DirectEndpoint: rpcPrefix + worker.rpcEndpoint, + ContainerEndpoint: worker.rpcEndpoint, + DirectEndpoint: worker.rpcEndpoint, ServerSecret: agentAwareTask.ServerSecret, ClientSecret: agentAwareTask.ClientSecret, TaskID: agentAwareTask.TaskId, diff --git a/internal/worker/worker.go b/internal/worker/worker.go index a42ba746..5fa1041b 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/certifi/gocertifi" "github.com/cirruslabs/cirrus-ci-agent/api" + "github.com/cirruslabs/cirrus-ci-agent/pkg/grpchelper" "github.com/cirruslabs/cirrus-cli/internal/version" "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -17,7 +18,7 @@ import ( ) const ( - defaultRPCEndpoint = "grpc.cirrus-ci.com:443" + DefaultRPCEndpoint = "https://grpc.cirrus-ci.com:443" defaultPollIntervalSeconds = 10 ) @@ -28,6 +29,7 @@ var ( type Worker struct { rpcEndpoint string + rpcTarget string rpcInsecure bool rpcClient api.CirrusWorkersServiceClient @@ -46,7 +48,7 @@ type Worker struct { func New(opts ...Option) (*Worker, error) { worker := &Worker{ - rpcEndpoint: defaultRPCEndpoint, + rpcEndpoint: DefaultRPCEndpoint, userSpecifiedLabels: make(map[string]string), pollIntervalSeconds: defaultPollIntervalSeconds, @@ -62,6 +64,9 @@ func New(opts ...Option) (*Worker, error) { opt(worker) } + // Parse endpoint + worker.rpcTarget, worker.rpcInsecure = grpchelper.TransportSettings(worker.rpcEndpoint) + if worker.registrationToken == "" { return nil, fmt.Errorf("%w: must provide a registration token", ErrWorker) } @@ -123,7 +128,7 @@ func (worker *Worker) Run(ctx context.Context) error { } // https://github.com/grpc/grpc-go/blob/master/Documentation/concurrency.md - conn, err := grpc.DialContext(subCtx, worker.rpcEndpoint, rpcSecurity) + conn, err := grpc.DialContext(subCtx, worker.rpcTarget, rpcSecurity) if err != nil { worker.logger.Errorf("failed to dial %s: %v", worker.rpcEndpoint, err) } diff --git a/internal/worker/worker_test.go b/internal/worker/worker_test.go index e2802039..c0e8d976 100644 --- a/internal/worker/worker_test.go +++ b/internal/worker/worker_test.go @@ -76,8 +76,7 @@ func TestWorker(t *testing.T) { // Start the worker worker, err := worker.New( worker.WithRegistrationToken(registrationToken), - worker.WithRPCEndpoint(lis.Addr().String()), - worker.WithRPCInsecure(), + worker.WithRPCEndpoint("http://"+lis.Addr().String()), ) if err != nil { t.Fatal(err)