diff --git a/debugd/internal/debugd/metadata/cloudprovider/cloudprovider.go b/debugd/internal/debugd/metadata/cloudprovider/cloudprovider.go index fa676537bc..64a19aa3d8 100644 --- a/debugd/internal/debugd/metadata/cloudprovider/cloudprovider.go +++ b/debugd/internal/debugd/metadata/cloudprovider/cloudprovider.go @@ -88,8 +88,8 @@ func (f *Fetcher) DiscoverDebugdIPs(ctx context.Context) ([]string, error) { return ips, nil } -// DiscoverLoadbalancerIP gets load balancer IP from metadata API. -func (f *Fetcher) DiscoverLoadbalancerIP(ctx context.Context) (string, error) { +// DiscoverLoadBalancerIP gets load balancer IP from metadata API. +func (f *Fetcher) DiscoverLoadBalancerIP(ctx context.Context) (string, error) { lbHost, _, err := f.metaAPI.GetLoadBalancerEndpoint(ctx) if err != nil { return "", fmt.Errorf("retrieving load balancer endpoint: %w", err) diff --git a/debugd/internal/debugd/metadata/cloudprovider/cloudprovider_test.go b/debugd/internal/debugd/metadata/cloudprovider/cloudprovider_test.go index 2868fecab3..53eb3ad8ed 100644 --- a/debugd/internal/debugd/metadata/cloudprovider/cloudprovider_test.go +++ b/debugd/internal/debugd/metadata/cloudprovider/cloudprovider_test.go @@ -121,7 +121,7 @@ func TestDiscoverDebugIPs(t *testing.T) { } } -func TestDiscoverLoadbalancerIP(t *testing.T) { +func TestDiscoverLoadBalancerIP(t *testing.T) { ip := "192.0.2.1" someErr := errors.New("failed") @@ -148,7 +148,7 @@ func TestDiscoverLoadbalancerIP(t *testing.T) { metaAPI: tc.metaAPI, } - ip, err := fetcher.DiscoverLoadbalancerIP(context.Background()) + ip, err := fetcher.DiscoverLoadBalancerIP(context.Background()) if tc.wantErr { assert.Error(err) diff --git a/debugd/internal/debugd/metadata/scheduler.go b/debugd/internal/debugd/metadata/scheduler.go index 3b7cac549c..eb04e5ade1 100644 --- a/debugd/internal/debugd/metadata/scheduler.go +++ b/debugd/internal/debugd/metadata/scheduler.go @@ -19,6 +19,7 @@ import ( // Fetcher retrieves other debugd IPs from cloud provider metadata. type Fetcher interface { DiscoverDebugdIPs(ctx context.Context) ([]string, error) + DiscoverLoadBalancerIP(ctx context.Context) (string, error) } // Scheduler schedules fetching of metadata using timers. @@ -51,23 +52,35 @@ func (s *Scheduler) Start(ctx context.Context, wg *sync.WaitGroup) { defer ticker.Stop() for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + ips, err := s.fetcher.DiscoverDebugdIPs(ctx) if err != nil { s.log.With(zap.Error(err)).Warnf("Discovering debugd IPs failed") } - if err == nil { - s.log.With(zap.Strings("ips", ips)).Infof("Discovered instances") - s.download(ctx, ips) - if s.deploymentDone && s.infoDone { - return - } + + lbip, err := s.fetcher.DiscoverLoadBalancerIP(ctx) + if err != nil { + s.log.With(zap.Error(err)).Warnf("Discovering load balancer IP failed") + } else { + ips = append(ips, lbip) } - select { - case <-ctx.Done(): + if len(ips) == 0 { + s.log.With(zap.Error(err)).Warnf("No debugd IPs discovered") + continue + } + + s.log.With(zap.Strings("ips", ips)).Infof("Discovered instances") + s.download(ctx, ips) + if s.deploymentDone && s.infoDone { return - case <-ticker.C: } + } }() } diff --git a/debugd/internal/debugd/metadata/scheduler_test.go b/debugd/internal/debugd/metadata/scheduler_test.go index 72dac43eff..8102079274 100644 --- a/debugd/internal/debugd/metadata/scheduler_test.go +++ b/debugd/internal/debugd/metadata/scheduler_test.go @@ -33,32 +33,47 @@ func TestSchedulerStart(t *testing.T) { wantInfoDownloads []string }{ "no errors occur": { - fetcher: stubFetcher{ips: []string{"192.0.2.1", "192.0.2.2"}}, + fetcher: stubFetcher{ips: []string{"192.0.2.1", "192.0.2.2"}, loadBalancerIP: "192.0.2.3"}, downloader: stubDownloader{}, - wantDiscoverCount: 1, + wantDiscoverCount: 2, wantDeploymentDownloads: []string{"192.0.2.1"}, wantInfoDownloads: []string{"192.0.2.1"}, }, - "download deployment fails": { + "no load balancer is discovered": { fetcher: stubFetcher{ips: []string{"192.0.2.1", "192.0.2.2"}}, - downloader: stubDownloader{downloadDeploymentErrs: []error{someErr, someErr}}, + downloader: stubDownloader{}, wantDiscoverCount: 2, - wantDeploymentDownloads: []string{"192.0.2.1", "192.0.2.2", "192.0.2.1"}, + wantDeploymentDownloads: []string{"192.0.2.1"}, wantInfoDownloads: []string{"192.0.2.1"}, }, - "download info fails": { - fetcher: stubFetcher{ips: []string{"192.0.2.1", "192.0.2.2"}}, - downloader: stubDownloader{downloadInfoErrs: []error{someErr, someErr}}, + "no nodes are discovered": { + fetcher: stubFetcher{loadBalancerIP: "192.0.2.3"}, + downloader: stubDownloader{}, wantDiscoverCount: 2, + wantDeploymentDownloads: []string{"192.0.2.3"}, + wantInfoDownloads: []string{"192.0.2.3"}, + }, + "download deployment fails": { + fetcher: stubFetcher{ips: []string{"192.0.2.1", "192.0.2.2"}, loadBalancerIP: "192.0.2.3"}, + downloader: stubDownloader{downloadDeploymentErrs: []error{someErr, someErr, someErr}}, + wantDiscoverCount: 4, + wantDeploymentDownloads: []string{"192.0.2.1", "192.0.2.2", "192.0.2.3", "192.0.2.1"}, + wantInfoDownloads: []string{"192.0.2.1"}, + }, + "download info fails": { + fetcher: stubFetcher{ips: []string{"192.0.2.1", "192.0.2.2"}, loadBalancerIP: "192.0.2.3"}, + downloader: stubDownloader{downloadInfoErrs: []error{someErr, someErr, someErr}}, + wantDiscoverCount: 4, wantDeploymentDownloads: []string{"192.0.2.1"}, - wantInfoDownloads: []string{"192.0.2.1", "192.0.2.2", "192.0.2.1"}, + wantInfoDownloads: []string{"192.0.2.1", "192.0.2.2", "192.0.2.3", "192.0.2.1"}, }, "endpoint discovery fails": { fetcher: stubFetcher{ - discoverErrs: []error{someErr, someErr, someErr}, - ips: []string{"192.0.2.1", "192.0.2.2"}, + discoverErrs: []error{someErr, someErr, someErr}, + discoverLoadBalancerIPErr: someErr, + ips: []string{"192.0.2.1", "192.0.2.2"}, }, - wantDiscoverCount: 4, + wantDiscoverCount: 8, wantDeploymentDownloads: []string{"192.0.2.1"}, wantInfoDownloads: []string{"192.0.2.1"}, }, @@ -90,7 +105,11 @@ type stubFetcher struct { ips []string discoverErrs []error discoverErrIdx int - discoverCalls int + + discoverCalls int + + loadBalancerIP string + discoverLoadBalancerIPErr error } func (s *stubFetcher) DiscoverDebugdIPs(_ context.Context) ([]string, error) { @@ -104,6 +123,11 @@ func (s *stubFetcher) DiscoverDebugdIPs(_ context.Context) ([]string, error) { return s.ips, nil } +func (s *stubFetcher) DiscoverLoadBalancerIP(_ context.Context) (string, error) { + s.discoverCalls++ + return s.loadBalancerIP, s.discoverLoadBalancerIPErr +} + type stubDownloader struct { downloadDeploymentErrs []error downloadDeploymentErrIdx int diff --git a/debugd/internal/debugd/server/server.go b/debugd/internal/debugd/server/server.go index f54af49708..551230ae71 100644 --- a/debugd/internal/debugd/server/server.go +++ b/debugd/internal/debugd/server/server.go @@ -133,9 +133,6 @@ func (s *debugdServer) UploadFiles(stream pb.Debugd_UploadFilesServer) error { // DownloadFiles streams the previously received files to other instances. func (s *debugdServer) DownloadFiles(_ *pb.DownloadFilesRequest, stream pb.Debugd_DownloadFilesServer) error { s.log.Infof("Sending files to other instance") - if !s.transfer.CanSend() { - return errors.New("cannot send files at this time") - } return s.transfer.SendFiles(stream) } @@ -185,5 +182,4 @@ type fileTransferer interface { RecvFiles(stream filetransfer.RecvFilesStream) error SendFiles(stream filetransfer.SendFilesStream) error GetFiles() []filetransfer.FileStat - CanSend() bool } diff --git a/debugd/internal/debugd/server/server_test.go b/debugd/internal/debugd/server/server_test.go index e596927e83..5837a5071d 100644 --- a/debugd/internal/debugd/server/server_test.go +++ b/debugd/internal/debugd/server/server_test.go @@ -228,10 +228,6 @@ func TestDownloadFiles(t *testing.T) { canSend: true, wantSendFileCalls: 1, }, - "transfer is not ready for sending": { - request: &pb.DownloadFilesRequest{}, - wantRecvErr: true, - }, } for name, tc := range testCases { diff --git a/debugd/internal/filetransfer/filetransfer.go b/debugd/internal/filetransfer/filetransfer.go index 13e266e3dc..ff90bdf09b 100644 --- a/debugd/internal/filetransfer/filetransfer.go +++ b/debugd/internal/filetransfer/filetransfer.go @@ -13,6 +13,7 @@ import ( "io" "io/fs" "sync" + "sync/atomic" "github.com/edgelesssys/constellation/v2/debugd/internal/debugd" "github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer/streamer" @@ -33,10 +34,10 @@ type SendFilesStream interface { // FileTransferer manages sending and receiving of files. type FileTransferer struct { - mux sync.RWMutex + fileMux sync.RWMutex log *logger.Logger receiveStarted bool - receiveFinished bool + receiveFinished atomic.Bool files []FileStat streamer streamReadWriter showProgress bool @@ -52,12 +53,15 @@ func New(log *logger.Logger, streamer streamReadWriter, showProgress bool) *File } // SendFiles sends files to the given stream. +// If the FileTransferer has not received any files to send, an error is returned. func (s *FileTransferer) SendFiles(stream SendFilesStream) error { - s.mux.RLock() - defer s.mux.RUnlock() - if !s.receiveFinished { + if !s.receiveFinished.Load() { return errors.New("cannot send files before receiving them") } + + s.fileMux.RLock() + defer s.fileMux.RUnlock() + for _, file := range s.files { if err := s.handleFileSend(stream, file); err != nil { return err @@ -68,8 +72,8 @@ func (s *FileTransferer) SendFiles(stream SendFilesStream) error { // RecvFiles receives files from the given stream. func (s *FileTransferer) RecvFiles(stream RecvFilesStream) (err error) { - s.mux.Lock() - defer s.mux.Unlock() + s.fileMux.Lock() + defer s.fileMux.Unlock() if err := s.startRecv(); err != nil { return err } @@ -89,30 +93,23 @@ func (s *FileTransferer) RecvFiles(stream RecvFilesStream) (err error) { // GetFiles returns the a copy of the list of files that have been received. func (s *FileTransferer) GetFiles() []FileStat { - s.mux.RLock() - defer s.mux.RUnlock() + s.fileMux.RLock() + defer s.fileMux.RUnlock() res := make([]FileStat, len(s.files)) copy(res, s.files) return res } // SetFiles sets the list of files that can be sent. +// This function is used for a sender which has not received any files through +// this FileTransferer i.e. the CLI. func (s *FileTransferer) SetFiles(files []FileStat) { - s.mux.Lock() - defer s.mux.Unlock() + s.fileMux.Lock() + defer s.fileMux.Unlock() res := make([]FileStat, len(files)) copy(res, files) s.files = res - s.receiveFinished = true -} - -// CanSend returns true if the file receive has finished. -// This is called to determine if a debugd instance can request files from this server. -func (s *FileTransferer) CanSend() bool { - s.mux.RLock() - defer s.mux.RUnlock() - ret := s.receiveFinished - return ret + s.receiveFinished.Store(true) } func (s *FileTransferer) handleFileSend(stream SendFilesStream, file FileStat) error { @@ -173,7 +170,7 @@ func (s *FileTransferer) handleFileRecv(stream RecvFilesStream) (bool, error) { // startRecv marks the file receive as started. It returns an error if receiving has already started. func (s *FileTransferer) startRecv() error { switch { - case s.receiveFinished: + case s.receiveFinished.Load(): return ErrReceiveFinished case s.receiveStarted: return ErrReceiveRunning @@ -193,7 +190,7 @@ func (s *FileTransferer) abortRecv() { // This allows other debugd instances to request files from this server. func (s *FileTransferer) finishRecv() { s.receiveStarted = false - s.receiveFinished = true + s.receiveFinished.Store(true) } // addFile adds a file to the list of received files. diff --git a/debugd/internal/filetransfer/filetransfer_test.go b/debugd/internal/filetransfer/filetransfer_test.go index e643e98229..73a977e743 100644 --- a/debugd/internal/filetransfer/filetransfer_test.go +++ b/debugd/internal/filetransfer/filetransfer_test.go @@ -25,11 +25,12 @@ func TestMain(m *testing.M) { func TestSendFiles(t *testing.T) { testCases := map[string]struct { - files *[]FileStat - sendErr error - readStreamErr error - wantHeaders []*pb.FileTransferMessage - wantErr bool + files *[]FileStat + receiveFinished bool + sendErr error + readStreamErr error + wantHeaders []*pb.FileTransferMessage + wantErr bool }{ "can send files": { files: &[]FileStat{ @@ -44,6 +45,7 @@ func TestSendFiles(t *testing.T) { OverrideServiceUnit: "somesvcB", }, }, + receiveFinished: true, wantHeaders: []*pb.FileTransferMessage{ { Kind: &pb.FileTransferMessage_Header{ @@ -65,8 +67,21 @@ func TestSendFiles(t *testing.T) { }, }, }, - "no files set": { - wantErr: true, + "not finished receiving": { + files: &[]FileStat{ + { + TargetPath: "testfileA", + Mode: 0o644, + OverrideServiceUnit: "somesvcA", + }, + { + TargetPath: "testfileB", + Mode: 0o644, + OverrideServiceUnit: "somesvcB", + }, + }, + receiveFinished: false, + wantErr: true, }, "send fails": { files: &[]FileStat{ @@ -76,8 +91,9 @@ func TestSendFiles(t *testing.T) { OverrideServiceUnit: "somesvcA", }, }, - sendErr: errors.New("send failed"), - wantErr: true, + receiveFinished: true, + sendErr: errors.New("send failed"), + wantErr: true, }, "read stream fails": { files: &[]FileStat{ @@ -87,8 +103,9 @@ func TestSendFiles(t *testing.T) { OverrideServiceUnit: "somesvcA", }, }, - readStreamErr: errors.New("read stream failed"), - wantErr: true, + receiveFinished: true, + readStreamErr: errors.New("read stream failed"), + wantErr: true, }, } @@ -99,10 +116,16 @@ func TestSendFiles(t *testing.T) { streamer := &stubStreamReadWriter{readStreamErr: tc.readStreamErr} stream := &stubSendFilesStream{sendErr: tc.sendErr} - transfer := New(logger.NewTest(t), streamer, false) + transfer := &FileTransferer{ + log: logger.NewTest(t), + streamer: streamer, + showProgress: false, + } if tc.files != nil { - transfer.SetFiles(*tc.files) + transfer.files = *tc.files } + transfer.receiveFinished.Store(tc.receiveFinished) + err := transfer.SendFiles(stream) if tc.wantErr { @@ -236,7 +259,7 @@ func TestRecvFiles(t *testing.T) { transfer.receiveStarted = true } if tc.recvAlreadyFinished { - transfer.receiveFinished = true + transfer.receiveFinished.Store(true) } err := transfer.RecvFiles(stream) @@ -290,34 +313,11 @@ func TestGetSetFiles(t *testing.T) { } gotFiles := transfer.GetFiles() assert.Equal(tc.wantFiles, gotFiles) - assert.Equal(tc.setFiles != nil, transfer.receiveFinished) + assert.Equal(tc.setFiles != nil, transfer.receiveFinished.Load()) }) } } -func TestCanSend(t *testing.T) { - assert := assert.New(t) - - streamer := &stubStreamReadWriter{} - stream := &stubRecvFilesStream{recvErr: io.EOF} - transfer := New(logger.NewTest(t), streamer, false) - assert.False(transfer.CanSend()) - - // manual set - transfer.SetFiles(nil) - assert.True(transfer.CanSend()) - - // reset - transfer.receiveStarted = false - transfer.receiveFinished = false - transfer.files = nil - assert.False(transfer.CanSend()) - - // receive files (empty) - assert.NoError(transfer.RecvFiles(stream)) - assert.True(transfer.CanSend()) -} - func TestConcurrency(t *testing.T) { ft := New(logger.NewTest(t), &stubStreamReadWriter{}, false) @@ -337,10 +337,6 @@ func TestConcurrency(t *testing.T) { ft.SetFiles([]FileStat{{SourcePath: "file", TargetPath: "file", Mode: 0o644}}) } - canSend := func() { - _ = ft.CanSend() - } - go sendFiles() go sendFiles() go sendFiles() @@ -357,10 +353,6 @@ func TestConcurrency(t *testing.T) { go setFiles() go setFiles() go setFiles() - go canSend() - go canSend() - go canSend() - go canSend() } type stubStreamReadWriter struct {