diff --git a/bosh/build_client.go b/bosh/build_client.go index 289ee1c2e..a0db5d62d 100644 --- a/bosh/build_client.go +++ b/bosh/build_client.go @@ -2,6 +2,7 @@ package bosh import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/instance" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ssh" "github.com/cloudfoundry/bosh-cli/v7/director" "github.com/pkg/errors" @@ -10,7 +11,7 @@ import ( boshlog "github.com/cloudfoundry/bosh-utils/logger" ) -func BuildClient(targetUrl, username, password, caCert, bbrVersion string, logger boshlog.Logger) (Client, error) { +func BuildClient(targetUrl, username, password, caCert, bbrVersion string, rateLimiter ratelimiter.RateLimiter, logger boshlog.Logger) (Client, error) { var client Client factoryConfig, err := director.NewConfigFromURL(targetUrl) @@ -44,7 +45,7 @@ func BuildClient(targetUrl, username, password, caCert, bbrVersion string, logge return client, errors.Wrap(err, "error building bosh director client") } - return NewClient(boshDirector, director.NewSSHOpts, ssh.NewSshRemoteRunner, logger, instance.NewJobFinder(bbrVersion, logger), NewBoshManifestQuerier), nil + return NewClient(boshDirector, director.NewSSHOpts, ssh.NewSshRemoteRunner, rateLimiter, logger, instance.NewJobFinder(bbrVersion, logger), NewBoshManifestQuerier), nil } func getDirectorInfo(directorFactory director.Factory, factoryConfig director.FactoryConfig) (director.Info, error) { diff --git a/bosh/build_client_test.go b/bosh/build_client_test.go index be463046a..55d91fd3d 100644 --- a/bosh/build_client_test.go +++ b/bosh/build_client_test.go @@ -8,6 +8,7 @@ import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/internal/cf-webmock/mockbosh" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/internal/cf-webmock/mockhttp" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/internal/cf-webmock/mockuaa" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" boshlog "github.com/cloudfoundry/bosh-utils/logger" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -50,7 +51,7 @@ var _ = Describe("BuildClient", func() { mockbosh.Manifest(deploymentName).RespondsWith([]byte("manifest contents")), ) - client, err := BuildClient(director.URL, username, password, caCert, bbrVersion, logger) + client, err := BuildClient(director.URL, username, password, caCert, bbrVersion, ratelimiter.NoOpRateLimiter{}, logger) Expect(err).NotTo(HaveOccurred()) manifest, err := client.GetManifest(deploymentName) @@ -75,7 +76,7 @@ var _ = Describe("BuildClient", func() { mockbosh.Manifest(deploymentName).RespondsWith([]byte("manifest contents")), ) - client, err := BuildClient(director.URL, username, password, caCert, bbrVersion, logger) + client, err := BuildClient(director.URL, username, password, caCert, bbrVersion, ratelimiter.NoOpRateLimiter{}, logger) Expect(err).NotTo(HaveOccurred()) manifest, err := client.GetManifest(deploymentName) @@ -90,7 +91,7 @@ var _ = Describe("BuildClient", func() { director.VerifyAndMock( mockbosh.Info().WithAuthTypeUAA(""), ) - _, err := BuildClient(director.URL, username, password, caCert, bbrVersion, logger) + _, err := BuildClient(director.URL, username, password, caCert, bbrVersion, ratelimiter.NoOpRateLimiter{}, logger) Expect(err).To(MatchError(ContainSubstring("invalid UAA URL"))) @@ -103,7 +104,7 @@ var _ = Describe("BuildClient", func() { caCertPath := "-----BEGIN" basicAuthDirectorURL := director.URL - _, err := BuildClient(basicAuthDirectorURL, username, password, caCertPath, bbrVersion, logger) + _, err := BuildClient(basicAuthDirectorURL, username, password, caCertPath, bbrVersion, ratelimiter.NoOpRateLimiter{}, logger) Expect(err).To(MatchError(ContainSubstring("Missing PEM block"))) }) @@ -113,7 +114,7 @@ var _ = Describe("BuildClient", func() { caCertPath := "" basicAuthDirectorURL := "" - _, err := BuildClient(basicAuthDirectorURL, username, password, caCertPath, bbrVersion, logger) + _, err := BuildClient(basicAuthDirectorURL, username, password, caCertPath, bbrVersion, ratelimiter.NoOpRateLimiter{}, logger) Expect(err).To(MatchError(ContainSubstring("invalid bosh URL"))) }) @@ -125,7 +126,7 @@ var _ = Describe("BuildClient", func() { mockbosh.Info().Fails("fooo!"), ) - _, err := BuildClient(director.URL, username, password, caCert, bbrVersion, logger) + _, err := BuildClient(director.URL, username, password, caCert, bbrVersion, ratelimiter.NoOpRateLimiter{}, logger) Expect(err).To(MatchError(ContainSubstring("bosh director unreachable or unhealthy"))) }) }) diff --git a/bosh/client.go b/bosh/client.go index c9c06071e..d74b0e181 100644 --- a/bosh/client.go +++ b/bosh/client.go @@ -5,6 +5,7 @@ import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/instance" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orchestrator" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ssh" "github.com/cloudfoundry/bosh-cli/v7/director" "github.com/cloudfoundry/bosh-utils/uuid" @@ -25,6 +26,7 @@ type BoshClient interface { func NewClient(boshDirector director.Director, sshOptsGenerator ssh.SSHOptsGenerator, remoteRunnerFactory ssh.RemoteRunnerFactory, + rateLimiter ratelimiter.RateLimiter, logger Logger, jobFinder instance.JobFinder, manifestQuerierCreator instance.ManifestQuerierCreator) Client { @@ -32,6 +34,7 @@ func NewClient(boshDirector director.Director, Director: boshDirector, SSHOptsGenerator: sshOptsGenerator, RemoteRunnerFactory: remoteRunnerFactory, + RateLimiter: rateLimiter, Logger: logger, jobFinder: jobFinder, manifestQuerierCreator: manifestQuerierCreator, @@ -42,6 +45,7 @@ type Client struct { director.Director ssh.SSHOptsGenerator ssh.RemoteRunnerFactory + ratelimiter.RateLimiter Logger jobFinder instance.JobFinder manifestQuerierCreator instance.ManifestQuerierCreator @@ -112,7 +116,7 @@ func (c Client) FindInstances(deploymentName string) ([]orchestrator.Instance, e return nil, errors.Wrap(err, "ssh.NewConnection.ParseAuthorizedKey failed") } - remoteRunner, err := c.RemoteRunnerFactory(host.Host, host.Username, privateKey, gossh.FixedHostKey(hostPublicKey), supportedEncryptionAlgorithms(hostPublicKey), c.Logger) + remoteRunner, err := c.RemoteRunnerFactory(host.Host, host.Username, privateKey, gossh.FixedHostKey(hostPublicKey), supportedEncryptionAlgorithms(hostPublicKey), c.RateLimiter, c.Logger) if err != nil { cleanupAlreadyMadeConnections(deployment, slugs, sshOpts) return nil, errors.Wrap(err, "failed to connect using ssh") diff --git a/bosh/client_test.go b/bosh/client_test.go index 0f5802a99..6a5736446 100644 --- a/bosh/client_test.go +++ b/bosh/client_test.go @@ -12,6 +12,7 @@ import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/instance" instancefakes "github.com/cloudfoundry-incubator/bosh-backup-and-restore/instance/fakes" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orchestrator" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ssh" sshfakes "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ssh/fakes" "github.com/cloudfoundry/bosh-cli/v7/director" @@ -43,7 +44,7 @@ var _ = Describe("Director", func() { var b bosh.BoshClient JustBeforeEach(func() { - b = bosh.NewClient(boshDirector, optsGenerator.Spy, remoteRunnerFactory.Spy, boshLogger, fakeJobFinder, manifestQuerierCreator.Spy) + b = bosh.NewClient(boshDirector, optsGenerator.Spy, remoteRunnerFactory.Spy, ratelimiter.NoOpRateLimiter{}, boshLogger, fakeJobFinder, manifestQuerierCreator.Spy) }) BeforeEach(func() { @@ -166,7 +167,7 @@ var _ = Describe("Director", func() { It("creates a remote runner for each host", func() { Expect(remoteRunnerFactory.CallCount()).To(Equal(1)) - host, username, privateKey, _, hostPublicKeyAlgorithm, logger := remoteRunnerFactory.ArgsForCall(0) + host, username, privateKey, _, hostPublicKeyAlgorithm, _, logger := remoteRunnerFactory.ArgsForCall(0) Expect(host).To(Equal("10.0.0.0")) Expect(username).To(Equal("username")) Expect(privateKey).To(Equal("private_key")) @@ -195,7 +196,7 @@ var _ = Describe("Director", func() { It("uses the specified port", func() { Expect(remoteRunnerFactory.CallCount()).To(Equal(1)) - host, username, privateKey, _, hostPublicKeyAlgorithm, logger := remoteRunnerFactory.ArgsForCall(0) + host, username, privateKey, _, hostPublicKeyAlgorithm, _, logger := remoteRunnerFactory.ArgsForCall(0) Expect(host).To(Equal("10.0.0.0:3457")) Expect(username).To(Equal("username")) Expect(privateKey).To(Equal("private_key")) @@ -328,14 +329,14 @@ var _ = Describe("Director", func() { It("creates a remote runner for each host", func() { Expect(remoteRunnerFactory.CallCount()).To(Equal(2)) - host, username, privateKey, _, hostPublicKeyAlgorithm, logger := remoteRunnerFactory.ArgsForCall(0) + host, username, privateKey, _, hostPublicKeyAlgorithm, _, logger := remoteRunnerFactory.ArgsForCall(0) Expect(host).To(Equal("10.0.0.1")) Expect(username).To(Equal("username")) Expect(privateKey).To(Equal("private_key")) Expect(hostPublicKeyAlgorithm).To(Equal(hostKeyAlgorithmRSA)) Expect(logger).To(Equal(boshLogger)) - host, username, privateKey, _, hostPublicKeyAlgorithm, logger = remoteRunnerFactory.ArgsForCall(1) + host, username, privateKey, _, hostPublicKeyAlgorithm, _, logger = remoteRunnerFactory.ArgsForCall(1) Expect(host).To(Equal("10.0.0.2")) Expect(username).To(Equal("username")) Expect(privateKey).To(Equal("private_key")) @@ -621,21 +622,21 @@ var _ = Describe("Director", func() { It("creates a remote runner for each host that has scripts, and the first instance of each group that doesn't", func() { Expect(remoteRunnerFactory.CallCount()).To(Equal(3)) - host, username, privateKey, _, hostPublicKeyAlgorithm, logger := remoteRunnerFactory.ArgsForCall(0) + host, username, privateKey, _, hostPublicKeyAlgorithm, _, logger := remoteRunnerFactory.ArgsForCall(0) Expect(host).To(Equal("10.0.0.1")) Expect(username).To(Equal("username")) Expect(privateKey).To(Equal("private_key")) Expect(hostPublicKeyAlgorithm).To(Equal(hostKeyAlgorithmRSA)) Expect(logger).To(Equal(boshLogger)) - host, username, privateKey, _, hostPublicKeyAlgorithm, logger = remoteRunnerFactory.ArgsForCall(1) + host, username, privateKey, _, hostPublicKeyAlgorithm, _, logger = remoteRunnerFactory.ArgsForCall(1) Expect(host).To(Equal("10.0.0.3")) Expect(username).To(Equal("username")) Expect(privateKey).To(Equal("private_key")) Expect(hostPublicKeyAlgorithm).To(Equal(hostKeyAlgorithmRSA)) Expect(logger).To(Equal(boshLogger)) - host, username, privateKey, _, hostPublicKeyAlgorithm, logger = remoteRunnerFactory.ArgsForCall(2) + host, username, privateKey, _, hostPublicKeyAlgorithm, _, logger = remoteRunnerFactory.ArgsForCall(2) Expect(host).To(Equal("10.0.0.4")) Expect(username).To(Equal("username")) Expect(privateKey).To(Equal("private_key")) @@ -689,7 +690,7 @@ var _ = Describe("Director", func() { It("uses the ECDSA algorithm to create its remote runners", func() { Expect(remoteRunnerFactory.CallCount()).To(Equal(1)) - _, _, _, _, hostPublicKeyAlgorithm, _ := remoteRunnerFactory.ArgsForCall(0) + _, _, _, _, hostPublicKeyAlgorithm, _, _ := remoteRunnerFactory.ArgsForCall(0) Expect(hostPublicKeyAlgorithm).To(Equal(hostKeyAlgorithmECDSA)) }) @@ -995,7 +996,7 @@ var _ = Describe("Director", func() { }}, nil } - remoteRunnerFactory.Stub = func(host, user, privateKey string, publicKeyCallback gossh.HostKeyCallback, publicKeyAlgorithm []string, logger ssh.Logger) (ssh.RemoteRunner, error) { + remoteRunnerFactory.Stub = func(host, user, privateKey string, publicKeyCallback gossh.HostKeyCallback, publicKeyAlgorithm []string, rateLimiter ratelimiter.RateLimiter, logger ssh.Logger) (ssh.RemoteRunner, error) { if host == "10.0.0.0_job1" { return remoteRunner, nil } diff --git a/cli/command/all_deployments.go b/cli/command/all_deployments.go index bea166b1c..687623547 100644 --- a/cli/command/all_deployments.go +++ b/cli/command/all_deployments.go @@ -15,6 +15,7 @@ import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/executor/deployment" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/factory" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orchestrator" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/urfave/cli" ) @@ -128,6 +129,18 @@ func getDeploymentParams(c *cli.Context) (string, string, string, string, string return username, password, target, caCert, bbrVersion, debug, deployment, allDeployments } +func getConnectionRateLimiter(c *cli.Context) (ratelimiter.RateLimiter, error) { + enabled := c.Parent().Bool("rate-limiting") + maxConnections := c.Parent().Int("rate-limiting-max-connections") + duration := c.Parent().String("rate-limiting-duration") + + if enabled { + return ratelimiter.NewConnectionRateLimiter(maxConnections, duration) + } + return ratelimiter.NewNoOpRateLimiter(), nil + +} + type DeploymentExecutable struct { action ActionFunc name string diff --git a/cli/command/deployment_backup.go b/cli/command/deployment_backup.go index 313808b17..c0e9e396c 100644 --- a/cli/command/deployment_backup.go +++ b/cli/command/deployment_backup.go @@ -7,6 +7,7 @@ import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/executor/deployment" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/factory" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orchestrator" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/urfave/cli" ) @@ -50,17 +51,23 @@ func (d DeploymentBackupCommand) Action(c *cli.Context) error { unsafeLockFree := c.Bool("unsafe-lock-free") artifactPath := c.String("artifact-path") + rateLimiter, err := getConnectionRateLimiter(c) + + if err != nil { + return err + } + if allDeployments { if unsafeLockFree { return processError(orchestrator.NewError(fmt.Errorf("Cannot use the --unsafe-lock-free flag in conjunction with the --all-deployments flag"))) } - return backupAll(target, username, password, caCert, artifactPath, withManifest, bbrVersion, debug) + return backupAll(target, username, password, caCert, artifactPath, withManifest, bbrVersion, debug, rateLimiter) } - return backupSingleDeployment(deployment, target, username, password, caCert, artifactPath, withManifest, bbrVersion, unsafeLockFree, debug) + return backupSingleDeployment(deployment, target, username, password, caCert, artifactPath, withManifest, bbrVersion, unsafeLockFree, debug, rateLimiter) } -func backupAll(target, username, password, caCert, artifactPath string, withManifest bool, bbrVersion string, debug bool) error { +func backupAll(target, username, password, caCert, artifactPath string, withManifest bool, bbrVersion string, debug bool, rateLimiter ratelimiter.RateLimiter) error { backupAction := func(deploymentName string) orchestrator.Error { timestamp := time.Now().UTC().Format(artifactTimeStampFormat) logFilePath, buffer, logger, logErr := createLogger(timestamp, artifactPath, deploymentName, debug) @@ -76,6 +83,7 @@ func backupAll(target, username, password, caCert, artifactPath string, withMani withManifest, false, bbrVersion, + rateLimiter, logger, timestamp, ) @@ -106,7 +114,7 @@ func backupAll(target, username, password, caCert, artifactPath string, withMani fmt.Println("Starting backup...") logger, _ := factory.BuildBoshLoggerWithCustomBuffer(debug) - boshClient, err := factory.BuildBoshClient(target, username, password, caCert, bbrVersion, logger) + boshClient, err := factory.BuildBoshClient(target, username, password, caCert, bbrVersion, rateLimiter, logger) if err != nil { return processError(orchestrator.NewError(err)) } @@ -119,11 +127,11 @@ func backupAll(target, username, password, caCert, artifactPath string, withMani deployment.NewParallelExecutor()) } -func backupSingleDeployment(deployment, target, username, password, caCert, artifactPath string, withManifest bool, bbrVersion string, unsafeLockFree, debug bool) error { +func backupSingleDeployment(deployment, target, username, password, caCert, artifactPath string, withManifest bool, bbrVersion string, unsafeLockFree, debug bool, rateLimiter ratelimiter.RateLimiter) error { logger := factory.BuildBoshLogger(debug) timeStamp := time.Now().UTC().Format(artifactTimeStampFormat) - backuper, err := factory.BuildDeploymentBackuper(target, username, password, caCert, withManifest, unsafeLockFree, bbrVersion, logger, timeStamp) + backuper, err := factory.BuildDeploymentBackuper(target, username, password, caCert, withManifest, unsafeLockFree, bbrVersion, rateLimiter, logger, timeStamp) if err != nil { return processError(orchestrator.NewError(err)) } diff --git a/cli/command/deployment_backup_cleanup.go b/cli/command/deployment_backup_cleanup.go index d0a0736f6..1ec09bee6 100644 --- a/cli/command/deployment_backup_cleanup.go +++ b/cli/command/deployment_backup_cleanup.go @@ -5,6 +5,7 @@ import ( "time" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/executor/deployment" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/factory" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orchestrator" @@ -31,6 +32,12 @@ func (d DeploymentBackupCleanupCommand) Action(c *cli.Context) error { username, password, target, caCert, bbrVersion, debug, deployment, allDeployments := getDeploymentParams(c) + rateLimiter, err := getConnectionRateLimiter(c) + + if err != nil { + return err + } + if !allDeployments { logger := factory.BuildBoshLogger(debug) @@ -40,6 +47,7 @@ func (d DeploymentBackupCleanupCommand) Action(c *cli.Context) error { password, caCert, c.App.Version, + rateLimiter, logger, ) if err != nil { @@ -50,10 +58,10 @@ func (d DeploymentBackupCleanupCommand) Action(c *cli.Context) error { return processError(cleanupErr) } - return cleanupAllDeployments(target, username, password, caCert, bbrVersion, debug) + return cleanupAllDeployments(target, username, password, caCert, bbrVersion, debug, rateLimiter) } -func cleanupAllDeployments(target, username, password, caCert, bbrVersion string, debug bool) error { +func cleanupAllDeployments(target, username, password, caCert, bbrVersion string, debug bool, rateLimiter ratelimiter.RateLimiter) error { cleanupAction := func(deploymentName string) orchestrator.Error { timestamp := time.Now().UTC().Format(artifactTimeStampFormat) logFilePath, buffer, logger, logErr := createLogger(timestamp, "", deploymentName, debug) @@ -67,6 +75,7 @@ func cleanupAllDeployments(target, username, password, caCert, bbrVersion string password, caCert, bbrVersion, + rateLimiter, logger, ) @@ -93,7 +102,7 @@ func cleanupAllDeployments(target, username, password, caCert, bbrVersion string logger, _ := factory.BuildBoshLoggerWithCustomBuffer(debug) - boshClient, err := factory.BuildBoshClient(target, username, password, caCert, bbrVersion, logger) + boshClient, err := factory.BuildBoshClient(target, username, password, caCert, bbrVersion, rateLimiter, logger) if err != nil { return err } diff --git a/cli/command/deployment_pre_backup_check.go b/cli/command/deployment_pre_backup_check.go index c7816595b..7c8916001 100644 --- a/cli/command/deployment_pre_backup_check.go +++ b/cli/command/deployment_pre_backup_check.go @@ -30,13 +30,20 @@ func (d DeploymentPreBackupCheck) Cli() cli.Command { func (d DeploymentPreBackupCheck) Action(c *cli.Context) error { username, password, target, caCert, bbrVersion, debug, deployment, allDeployments := getDeploymentParams(c) + + rateLimiter, err := getConnectionRateLimiter(c) + + if err != nil { + return err + } + var logger logger.Logger if allDeployments { logger, _ = factory.BuildBoshLoggerWithCustomBuffer(debug) } else { logger = factory.BuildBoshLogger(debug) } - boshClient, err := factory.BuildBoshClient(target, username, password, caCert, bbrVersion, logger) + boshClient, err := factory.BuildBoshClient(target, username, password, caCert, bbrVersion, rateLimiter, logger) if err != nil { return processError(orchestrator.NewError(err)) } diff --git a/cli/command/deployment_restore.go b/cli/command/deployment_restore.go index 3d91e4bc0..e8ebd0e1d 100644 --- a/cli/command/deployment_restore.go +++ b/cli/command/deployment_restore.go @@ -37,12 +37,19 @@ func (d DeploymentRestoreCommand) Action(c *cli.Context) error { deployment := c.Parent().String("deployment") artifactPath := c.String("artifact-path") + rateLimiter, err := getConnectionRateLimiter(c) + + if err != nil { + return err + } + restorer, err := factory.BuildDeploymentRestorer(c.Parent().String("target"), c.Parent().String("username"), c.Parent().String("password"), c.Parent().String("ca-cert"), c.App.Version, - c.GlobalBool("debug")) + c.GlobalBool("debug"), + rateLimiter) if err != nil { return processError(orchestrator.NewError(err)) diff --git a/cli/command/deployment_restore_cleanup.go b/cli/command/deployment_restore_cleanup.go index 8b7f261b3..b330269bf 100644 --- a/cli/command/deployment_restore_cleanup.go +++ b/cli/command/deployment_restore_cleanup.go @@ -24,13 +24,20 @@ func (d DeploymentRestoreCleanupCommand) Cli() cli.Command { func (d DeploymentRestoreCleanupCommand) Action(c *cli.Context) error { trapSigint(true) + rateLimiter, err := getConnectionRateLimiter(c) + + if err != nil { + return err + } + cleaner, err := factory.BuildDeploymentRestoreCleanuper(c.Parent().String("target"), c.Parent().String("username"), c.Parent().String("password"), c.Parent().String("ca-cert"), c.App.Version, c.Bool("with-manifest"), - c.GlobalBool("debug")) + c.GlobalBool("debug"), + rateLimiter) if err != nil { return processError(orchestrator.NewError(err)) diff --git a/cli/command/director_backup.go b/cli/command/director_backup.go index 47e04b11c..8804e9cad 100644 --- a/cli/command/director_backup.go +++ b/cli/command/director_backup.go @@ -36,12 +36,19 @@ func (checkCommand DirectorBackupCommand) Action(c *cli.Context) error { directorName := extractNameFromAddress(c.Parent().String("host")) timeStamp := time.Now().UTC().Format(artifactTimeStampFormat) + rateLimiter, err := getConnectionRateLimiter(c) + + if err != nil { + return err + } + backuper := factory.BuildDirectorBackuper( c.Parent().String("host"), c.Parent().String("username"), c.Parent().String("private-key-path"), c.App.Version, c.GlobalBool("debug"), + rateLimiter, timeStamp) backupErr := backuper.Backup(directorName, c.String("artifact-path")) diff --git a/cli/command/director_backup_cleanup.go b/cli/command/director_backup_cleanup.go index da68dc758..41264825f 100644 --- a/cli/command/director_backup_cleanup.go +++ b/cli/command/director_backup_cleanup.go @@ -24,11 +24,18 @@ func (d DirectorBackupCleanupCommand) Action(c *cli.Context) error { directorName := extractNameFromAddress(c.Parent().String("host")) + rateLimiter, err := getConnectionRateLimiter(c) + + if err != nil { + return err + } + cleaner := factory.BuildDirectorBackupCleaner(c.Parent().String("host"), c.Parent().String("username"), c.Parent().String("private-key-path"), c.App.Version, c.GlobalBool("debug"), + rateLimiter, ) cleanupErr := cleaner.Cleanup(directorName) diff --git a/cli/command/director_pre_backup_check.go b/cli/command/director_pre_backup_check.go index b04560b53..d32064e31 100644 --- a/cli/command/director_pre_backup_check.go +++ b/cli/command/director_pre_backup_check.go @@ -26,24 +26,31 @@ func NewDirectorPreBackupCheckCommand() DirectorPreBackupCheckCommand { func (checkCommand DirectorPreBackupCheckCommand) Action(c *cli.Context) error { directorName := extractNameFromAddress(c.Parent().String("host")) + rateLimiter, err := getConnectionRateLimiter(c) + + if err != nil { + return err + } + backupChecker := factory.BuildDirectorBackupChecker( c.Parent().String("host"), c.Parent().String("username"), c.Parent().String("private-key-path"), c.App.Version, c.GlobalBool("debug"), + rateLimiter, ) - err := backupChecker.Check(directorName) + orchErr := backupChecker.Check(directorName) if err != nil { fmt.Printf("Director cannot be backed up.\n") - if err.ContainsArtifactDirError() { - return processErrorWithFooter(err, backupCleanupAdvisedNotice) + if orchErr.ContainsArtifactDirError() { + return processErrorWithFooter(orchErr, backupCleanupAdvisedNotice) } - return processError(err) + return processError(orchErr) } fmt.Printf("Director can be backed up.\n") diff --git a/cli/command/director_restore.go b/cli/command/director_restore.go index 189353bec..64f4dfa03 100644 --- a/cli/command/director_restore.go +++ b/cli/command/director_restore.go @@ -38,12 +38,19 @@ func (cmd DirectorRestoreCommand) Action(c *cli.Context) error { directorName := extractNameFromAddress(c.Parent().String("host")) artifactPath := c.String("artifact-path") + rateLimiter, err := getConnectionRateLimiter(c) + + if err != nil { + return err + } + restorer := factory.BuildDirectorRestorer( c.Parent().String("host"), c.Parent().String("username"), c.Parent().String("private-key-path"), c.App.Version, c.GlobalBool("debug"), + rateLimiter, ) restoreErr := restorer.Restore(directorName, artifactPath) diff --git a/cli/command/director_restore_cleanup.go b/cli/command/director_restore_cleanup.go index 0e1d2b28a..591896ce8 100644 --- a/cli/command/director_restore_cleanup.go +++ b/cli/command/director_restore_cleanup.go @@ -24,12 +24,19 @@ func (d DirectorRestoreCleanupCommand) Action(c *cli.Context) error { directorName := extractNameFromAddress(c.Parent().String("host")) + rateLimiter, err := getConnectionRateLimiter(c) + + if err != nil { + return err + } + cleaner := factory.BuildDirectorRestoreCleaner( c.Parent().String("host"), c.Parent().String("username"), c.Parent().String("private-key-path"), c.App.Version, c.GlobalBool("debug"), + rateLimiter, ) cleanupErr := cleaner.Cleanup(directorName) diff --git a/cmd/bbr/main.go b/cmd/bbr/main.go index 29c00a8df..7aeaef0c7 100644 --- a/cmd/bbr/main.go +++ b/cmd/bbr/main.go @@ -143,6 +143,20 @@ func availableDeploymentFlags() []cli.Flag { Name: "all-deployments", Usage: "Run command for all deployments. Omit if '--deployment' is provided. Currently only supported for: pre-backup-check, backup and backup-cleanup", }, + cli.BoolFlag{ + Name: "rate-limiting", + Usage: "Enable ssh connection rate limiting", + }, + cli.IntFlag{ + Name: "rate-limiting-max-connections", + Usage: "Set the maximum amount of ssh connections that can be opened in configurable duration window (used with --rate-limiting)", + Value: 20, + }, + cli.StringFlag{ + Name: "rate-limiting-duration", + Usage: "Set the duration window (example: 20s or 1m) (used with --rate-limiting)", + Value: "60s", + }, } } @@ -167,5 +181,19 @@ func availableDirectorFlags() []cli.Flag { Name: "debug", Usage: "Enable debug logs", }, + cli.BoolFlag{ + Name: "rate-limiting", + Usage: "Enable ssh connection rate limiting", + }, + cli.IntFlag{ + Name: "rate-limiting-max-connections", + Usage: "Set the maximum amount of ssh connections that can be opened in configurable duration window (used with --rate-limiting)", + Value: 20, + }, + cli.StringFlag{ + Name: "rate-limiting-duration", + Usage: "Set the duration window (example: 20s or 1m) (used with --rate-limiting)", + Value: "60s", + }, } } diff --git a/factory/bosh_deployment_manager.go b/factory/bosh_deployment_manager.go index 5256de34a..0f9608709 100644 --- a/factory/bosh_deployment_manager.go +++ b/factory/bosh_deployment_manager.go @@ -2,13 +2,14 @@ package factory import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/bosh" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" boshlog "github.com/cloudfoundry/bosh-utils/logger" boshcmd "github.com/cloudfoundry/bosh-cli/v7/cmd/opts" boshsys "github.com/cloudfoundry/bosh-utils/system" ) -func BuildBoshClient(targetUrl, username, password, caCertPathOrValue, bbrVersion string, logger boshlog.Logger) (bosh.Client, error) { +func BuildBoshClient(targetUrl, username, password, caCertPathOrValue, bbrVersion string, rateLimiter ratelimiter.RateLimiter, logger boshlog.Logger) (bosh.Client, error) { var boshClient bosh.Client var err error fs := boshsys.NewOsFileSystem(logger) @@ -20,7 +21,7 @@ func BuildBoshClient(targetUrl, username, password, caCertPathOrValue, bbrVersio return boshClient, err } - boshClient, err = bosh.BuildClient(targetUrl, username, password, caCertArg.Content, bbrVersion, logger) + boshClient, err = bosh.BuildClient(targetUrl, username, password, caCertArg.Content, bbrVersion, rateLimiter, logger) if err != nil { return boshClient, err } diff --git a/factory/deployment_backup_cleanuper.go b/factory/deployment_backup_cleanuper.go index 6bae1a769..3bcf356fd 100644 --- a/factory/deployment_backup_cleanuper.go +++ b/factory/deployment_backup_cleanuper.go @@ -5,6 +5,7 @@ import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/executor" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orchestrator" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orderer" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/cloudfoundry/bosh-utils/logger" ) @@ -14,10 +15,11 @@ func BuildDeploymentBackupCleanuper( password, caCert, bbrVersion string, + rateLimiter ratelimiter.RateLimiter, logger logger.Logger, ) (*orchestrator.BackupCleaner, error) { - boshClient, err := BuildBoshClient(target, username, password, caCert, bbrVersion, logger) + boshClient, err := BuildBoshClient(target, username, password, caCert, bbrVersion, rateLimiter, logger) if err != nil { return nil, err diff --git a/factory/deployment_backuper.go b/factory/deployment_backuper.go index a98dd2869..4a07acaa2 100644 --- a/factory/deployment_backuper.go +++ b/factory/deployment_backuper.go @@ -8,6 +8,7 @@ import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/executor" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orchestrator" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orderer" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" boshlog "github.com/cloudfoundry/bosh-utils/logger" ) @@ -19,10 +20,11 @@ func BuildDeploymentBackuper( withManifest bool, unsafeLockFree bool, bbrVersion string, + rateLimiter ratelimiter.RateLimiter, logger boshlog.Logger, timestamp string, ) (*orchestrator.Backuper, error) { - boshClient, err := BuildBoshClient(target, username, password, caCert, bbrVersion, logger) + boshClient, err := BuildBoshClient(target, username, password, caCert, bbrVersion, rateLimiter, logger) if err != nil { return nil, err } diff --git a/factory/deployment_restore_cleanuper.go b/factory/deployment_restore_cleanuper.go index ca459a5da..b01cc8d3c 100644 --- a/factory/deployment_restore_cleanuper.go +++ b/factory/deployment_restore_cleanuper.go @@ -5,6 +5,7 @@ import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/executor" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orchestrator" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orderer" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" ) func BuildDeploymentRestoreCleanuper(target, @@ -13,7 +14,8 @@ func BuildDeploymentRestoreCleanuper(target, caCert, bbrVersion string, withManifest, - isDebug bool) (*orchestrator.RestoreCleaner, error) { + isDebug bool, + rateLimiter ratelimiter.RateLimiter) (*orchestrator.RestoreCleaner, error) { logger := BuildLogger(isDebug) @@ -23,6 +25,7 @@ func BuildDeploymentRestoreCleanuper(target, password, caCert, bbrVersion, + rateLimiter, logger, ) diff --git a/factory/deployment_restorer.go b/factory/deployment_restorer.go index 80fb4bad1..e2c2ef91b 100644 --- a/factory/deployment_restorer.go +++ b/factory/deployment_restorer.go @@ -6,9 +6,10 @@ import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/executor" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orchestrator" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orderer" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" ) -func BuildDeploymentRestorer(target, username, password, caCert, bbrVersion string, debug bool) (*orchestrator.Restorer, error) { +func BuildDeploymentRestorer(target, username, password, caCert, bbrVersion string, debug bool, rateLimiter ratelimiter.RateLimiter) (*orchestrator.Restorer, error) { logger := BuildLogger(debug) boshClient, err := BuildBoshClient( target, @@ -16,6 +17,7 @@ func BuildDeploymentRestorer(target, username, password, caCert, bbrVersion stri password, caCert, bbrVersion, + rateLimiter, logger, ) if err != nil { diff --git a/factory/director_backup_checker.go b/factory/director_backup_checker.go index 17916081e..05fc2741e 100644 --- a/factory/director_backup_checker.go +++ b/factory/director_backup_checker.go @@ -4,11 +4,12 @@ import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/instance" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orchestrator" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orderer" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ssh" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/standalone" ) -func BuildDirectorBackupChecker(host, username, privateKeyPath, bbrVersion string, hasDebug bool) *orchestrator.BackupChecker { +func BuildDirectorBackupChecker(host, username, privateKeyPath, bbrVersion string, hasDebug bool, rateLimiter ratelimiter.RateLimiter) *orchestrator.BackupChecker { logger := BuildLogger(hasDebug) deploymentManager := standalone.NewDeploymentManager(logger, host, @@ -16,6 +17,7 @@ func BuildDirectorBackupChecker(host, username, privateKeyPath, bbrVersion strin privateKeyPath, instance.NewJobFinderOmitMetadataReleases(bbrVersion, logger), ssh.NewSshRemoteRunner, + rateLimiter, ) return orchestrator.NewBackupChecker(logger, deploymentManager, orderer.NewKahnBackupLockOrderer()) diff --git a/factory/director_backup_cleaner.go b/factory/director_backup_cleaner.go index f58d5f6b3..6ec997271 100644 --- a/factory/director_backup_cleaner.go +++ b/factory/director_backup_cleaner.go @@ -5,6 +5,7 @@ import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/instance" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orchestrator" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orderer" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ssh" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/standalone" ) @@ -13,7 +14,8 @@ func BuildDirectorBackupCleaner(host, username, privateKeyPath, bbrVersion string, - hasDebug bool) *orchestrator.BackupCleaner { + hasDebug bool, + rateLimiter ratelimiter.RateLimiter) *orchestrator.BackupCleaner { logger := BuildLogger(hasDebug) deploymentManager := standalone.NewDeploymentManager(logger, @@ -22,6 +24,7 @@ func BuildDirectorBackupCleaner(host, privateKeyPath, instance.NewJobFinderOmitMetadataReleases(bbrVersion, logger), ssh.NewSshRemoteRunner, + rateLimiter, ) return orchestrator.NewBackupCleaner(logger, deploymentManager, orderer.NewKahnBackupLockOrderer(), executor.NewParallelExecutor()) diff --git a/factory/director_backuper.go b/factory/director_backuper.go index 8951d76d8..2f8156a7e 100644 --- a/factory/director_backuper.go +++ b/factory/director_backuper.go @@ -8,11 +8,12 @@ import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/instance" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orchestrator" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orderer" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ssh" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/standalone" ) -func BuildDirectorBackuper(host, username, privateKeyPath, bbrVersion string, hasDebug bool, timeStamp string) *orchestrator.Backuper { +func BuildDirectorBackuper(host, username, privateKeyPath, bbrVersion string, hasDebug bool, rateLimiter ratelimiter.RateLimiter, timeStamp string) *orchestrator.Backuper { logger := BuildLogger(hasDebug) deploymentManager := standalone.NewDeploymentManager(logger, host, @@ -20,6 +21,7 @@ func BuildDirectorBackuper(host, username, privateKeyPath, bbrVersion string, ha privateKeyPath, instance.NewJobFinderOmitMetadataReleases(bbrVersion, logger), ssh.NewSshRemoteRunner, + rateLimiter, ) execr := executor.NewParallelExecutor() diff --git a/factory/director_restore_cleaner.go b/factory/director_restore_cleaner.go index f7c3745e1..5c2e840dc 100644 --- a/factory/director_restore_cleaner.go +++ b/factory/director_restore_cleaner.go @@ -5,6 +5,7 @@ import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/instance" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orchestrator" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orderer" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ssh" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/standalone" ) @@ -13,7 +14,8 @@ func BuildDirectorRestoreCleaner(host, username, privateKeyPath, bbrVersion string, - hasDebug bool) *orchestrator.RestoreCleaner { + hasDebug bool, + rateLimiter ratelimiter.RateLimiter) *orchestrator.RestoreCleaner { logger := BuildLogger(hasDebug) @@ -23,6 +25,7 @@ func BuildDirectorRestoreCleaner(host, privateKeyPath, instance.NewJobFinderOmitMetadataReleases(bbrVersion, logger), ssh.NewSshRemoteRunner, + rateLimiter, ) return orchestrator.NewRestoreCleaner(logger, deploymentManager, orderer.NewKahnRestoreLockOrderer(), executor.NewSerialExecutor()) diff --git a/factory/director_restorer.go b/factory/director_restorer.go index 0ad419368..a7d83386e 100644 --- a/factory/director_restorer.go +++ b/factory/director_restorer.go @@ -6,11 +6,12 @@ import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/instance" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orchestrator" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orderer" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ssh" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/standalone" ) -func BuildDirectorRestorer(host, username, privateKeyPath, bbrVersion string, hasDebug bool) *orchestrator.Restorer { +func BuildDirectorRestorer(host, username, privateKeyPath, bbrVersion string, hasDebug bool, rateLimiter ratelimiter.RateLimiter) *orchestrator.Restorer { logger := BuildLogger(hasDebug) deploymentManager := standalone.NewDeploymentManager(logger, host, @@ -18,6 +19,7 @@ func BuildDirectorRestorer(host, username, privateKeyPath, bbrVersion string, ha privateKeyPath, instance.NewJobFinderOmitMetadataReleases(bbrVersion, logger), ssh.NewSshRemoteRunner, + rateLimiter, ) return orchestrator.NewRestorer( diff --git a/ratelimiter/connection_rate_limiter.go b/ratelimiter/connection_rate_limiter.go new file mode 100644 index 000000000..0ca02bb1a --- /dev/null +++ b/ratelimiter/connection_rate_limiter.go @@ -0,0 +1,53 @@ +package ratelimiter + +import ( + "errors" + "fmt" + "time" +) + +type RateLimiter interface { + RateLimit() +} + +type ConnectionRateLimiter struct { + guard chan bool + duration time.Duration +} + +func NewConnectionRateLimiter(maxConnections int, durationString string) (RateLimiter, error) { + + if maxConnections < 1 || maxConnections > 100 { + errorMessage := "max connections cannot be less than 1 or greater than 100" + fmt.Println(errorMessage) + return nil, errors.New(errorMessage) + } + + duration, err := time.ParseDuration(durationString) + + if err != nil { + fmt.Printf("unable to parse rating limit duration: %s\n", err.Error()) + return nil, err + } + + if duration <= 0 || duration > (3600*time.Second) { + errorMessage := "duration cannot be 0 or greater than 3600 seconds" + fmt.Println(errorMessage) + return nil, errors.New(errorMessage) + } + + return &ConnectionRateLimiter{ + guard: make(chan bool, maxConnections), + duration: duration, + }, nil +} + +func (t *ConnectionRateLimiter) RateLimit() { + + t.guard <- true + + go func() { + time.Sleep(t.duration) + <-t.guard + }() +} diff --git a/ratelimiter/connection_rate_limiter_test.go b/ratelimiter/connection_rate_limiter_test.go new file mode 100644 index 000000000..b0888e5e4 --- /dev/null +++ b/ratelimiter/connection_rate_limiter_test.go @@ -0,0 +1,76 @@ +package ratelimiter_test + +import ( + "context" + "time" + + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("ConnectionRateLimiter", func() { + + Describe("RateLimit", func() { + Context("success", func() { + It("rate limits", func(ctx context.Context) { + rateLimiter, err := ratelimiter.NewConnectionRateLimiter(5, "1s") + + Expect(err).To(BeNil()) + + completion := make(chan struct{}, 10) + + for i := 0; i < 10; i++ { + go func() { + rateLimiter.RateLimit() + completion <- struct{}{} + }() + } + + time.Sleep(10 * time.Millisecond) + Expect(completion).To(HaveLen(5)) + + time.Sleep(1 * time.Second) + Expect(completion).To(HaveLen(10)) + + }, SpecTimeout(2*time.Second)) + }) + + Context("failure", func() { + It("throws and error if rate limit is less than 1", func() { + _, err := ratelimiter.NewConnectionRateLimiter(0, "1s") + + Expect(err).NotTo(BeNil()) + Expect(err.Error()).To(ContainSubstring("less than 1")) + }) + + It("throws and error if rate limit is greater than 100", func() { + _, err := ratelimiter.NewConnectionRateLimiter(101, "1s") + + Expect(err).NotTo(BeNil()) + Expect(err.Error()).To(ContainSubstring("greater than 100")) + }) + + It("throws and error if duration is less than 1", func() { + _, err := ratelimiter.NewConnectionRateLimiter(5, "0s") + + Expect(err).NotTo(BeNil()) + Expect(err.Error()).To(ContainSubstring("cannot be 0")) + }) + + It("throws and error if duration is greater than 3600", func() { + _, err := ratelimiter.NewConnectionRateLimiter(5, "3601s") + + Expect(err).NotTo(BeNil()) + Expect(err.Error()).To(ContainSubstring("greater than 3600")) + }) + + It("throws and error if duration is invalid", func() { + _, err := ratelimiter.NewConnectionRateLimiter(5, "1yxz") + + Expect(err).NotTo(BeNil()) + Expect(err.Error()).To(ContainSubstring("duration \"1yxz\"")) + }) + }) + }) +}) diff --git a/ratelimiter/noop_rate_limiter.go b/ratelimiter/noop_rate_limiter.go new file mode 100644 index 000000000..16ba4b8ee --- /dev/null +++ b/ratelimiter/noop_rate_limiter.go @@ -0,0 +1,12 @@ +package ratelimiter + +type NoOpRateLimiter struct { +} + +func NewNoOpRateLimiter() RateLimiter { + return NoOpRateLimiter{} +} + +func (n NoOpRateLimiter) RateLimit() { + +} diff --git a/ratelimiter/ratelimiter_suite_test.go b/ratelimiter/ratelimiter_suite_test.go new file mode 100644 index 000000000..16b4e7141 --- /dev/null +++ b/ratelimiter/ratelimiter_suite_test.go @@ -0,0 +1,13 @@ +package ratelimiter_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestRatelimiter(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Ratelimiter Suite") +} diff --git a/ssh/connection.go b/ssh/connection.go index 8aec77279..06fbcc025 100644 --- a/ssh/connection.go +++ b/ssh/connection.go @@ -15,6 +15,7 @@ import ( "net" "os" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" boshhttp "github.com/cloudfoundry/bosh-utils/httpclient" proxy "github.com/cloudfoundry/socks5-proxy" "github.com/pkg/errors" @@ -40,17 +41,17 @@ type Logger interface { var dialFunc boshhttp.DialContextFunc var dialFuncMutex sync.RWMutex -func NewConnection(hostName, userName, privateKey string, publicKeyCallback ssh.HostKeyCallback, publicKeyAlgorithm []string, logger Logger) (SSHConnection, error) { - return NewConnectionWithServerAliveInterval(hostName, userName, privateKey, publicKeyCallback, publicKeyAlgorithm, 60, logger) +func NewConnection(hostName, userName, privateKey string, publicKeyCallback ssh.HostKeyCallback, publicKeyAlgorithm []string, rateLimiter ratelimiter.RateLimiter, logger Logger) (SSHConnection, error) { + return NewConnectionWithServerAliveInterval(hostName, userName, privateKey, publicKeyCallback, publicKeyAlgorithm, 60, rateLimiter, logger) } -func NewConnectionWithServerAliveInterval(hostName, userName, privateKey string, publicKeyCallback ssh.HostKeyCallback, publicKeyAlgorithm []string, serverAliveInterval time.Duration, logger Logger) (SSHConnection, error) { +func NewConnectionWithServerAliveInterval(hostName, userName, privateKey string, publicKeyCallback ssh.HostKeyCallback, publicKeyAlgorithm []string, serverAliveInterval time.Duration, rateLimiter ratelimiter.RateLimiter, logger Logger) (SSHConnection, error) { parsedPrivateKey, err := ssh.ParsePrivateKey([]byte(privateKey)) if err != nil { return nil, errors.Wrap(err, "ssh.NewConnection.ParsePrivateKey failed") } - conn := Connection{ + conn := &Connection{ host: defaultToSSHPort(hostName), sshConfig: &ssh.ClientConfig{ User: userName, @@ -63,8 +64,8 @@ func NewConnectionWithServerAliveInterval(hostName, userName, privateKey string, logger: logger, serverAliveInterval: serverAliveInterval, dialFunc: createDialContextFunc(), + rateLimiter: rateLimiter, } - return conn, nil } @@ -74,9 +75,10 @@ type Connection struct { logger Logger serverAliveInterval time.Duration dialFunc boshhttp.DialContextFunc + rateLimiter ratelimiter.RateLimiter } -func (c Connection) Run(cmd string) (stdout, stderr []byte, exitCode int, err error) { +func (c *Connection) Run(cmd string) (stdout, stderr []byte, exitCode int, err error) { stdoutBuffer := bytes.NewBuffer([]byte{}) stderr, exitCode, err = c.Stream(cmd, stdoutBuffer) @@ -84,7 +86,7 @@ func (c Connection) Run(cmd string) (stdout, stderr []byte, exitCode int, err er return stdoutBuffer.Bytes(), stderr, exitCode, errors.Wrap(err, "ssh.Run failed") } -func (c Connection) Stream(cmd string, stdoutWriter io.Writer) (stderr []byte, exitCode int, err error) { +func (c *Connection) Stream(cmd string, stdoutWriter io.Writer) (stderr []byte, exitCode int, err error) { errBuffer := bytes.NewBuffer([]byte{}) exitCode, err = c.runInSession(cmd, stdoutWriter, errBuffer, nil) @@ -92,7 +94,7 @@ func (c Connection) Stream(cmd string, stdoutWriter io.Writer) (stderr []byte, e return errBuffer.Bytes(), exitCode, errors.Wrap(err, "ssh.Stream failed") } -func (c Connection) StreamStdin(cmd string, stdinReader io.Reader) (stdout, stderr []byte, exitCode int, err error) { +func (c *Connection) StreamStdin(cmd string, stdinReader io.Reader) (stdout, stderr []byte, exitCode int, err error) { stdoutBuffer := bytes.NewBuffer([]byte{}) stderrBuffer := bytes.NewBuffer([]byte{}) @@ -116,7 +118,9 @@ func (w *sessionClosingOnErrorWriter) Write(data []byte) (int, error) { return n, err } -func (c Connection) newClient() (*ssh.Client, error) { +func (c *Connection) newClient() (*ssh.Client, error) { + c.rateLimiter.RateLimit() + conn, err := c.dialFunc(context.Background(), "tcp", c.host) if err != nil { return nil, err @@ -171,7 +175,7 @@ func buildSSHSessionImpl(client *ssh.Client, stdin io.Reader, stdout, stderr io. var buildSSHSession = buildSSHSessionImpl -func (c Connection) runInSession(cmd string, stdout, stderr io.Writer, stdin io.Reader) (int, error) { +func (c *Connection) runInSession(cmd string, stdout, stderr io.Writer, stdin io.Reader) (int, error) { client, err := c.newClient() if err != nil { return -1, errors.Wrap(err, "ssh.Dial failed") @@ -216,7 +220,7 @@ func (c Connection) runInSession(cmd string, stdout, stderr io.Writer, stdin io. return 0, nil } -func (c Connection) startKeepAliveLoop(session SSHSession) chan struct{} { +func (c *Connection) startKeepAliveLoop(session SSHSession) chan struct{} { terminate := make(chan struct{}) go func() { for { @@ -235,7 +239,7 @@ func (c Connection) startKeepAliveLoop(session SSHSession) chan struct{} { return terminate } -func (c Connection) Username() string { +func (c *Connection) Username() string { return c.sshConfig.User } diff --git a/ssh/connection_test.go b/ssh/connection_test.go index a084ab93a..ec88155d8 100644 --- a/ssh/connection_test.go +++ b/ssh/connection_test.go @@ -7,6 +7,7 @@ import ( "io" "log" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ssh" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ssh/fakes" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/testcluster" @@ -52,7 +53,7 @@ var _ = Describe("Connection", func() { JustBeforeEach(func() { - conn, connErr = ssh.NewConnection(hostname, user, privateKey, gossh.FixedHostKey(hostPublicKey), []string{"rsa-sha2-256"}, logger) + conn, connErr = ssh.NewConnection(hostname, user, privateKey, gossh.FixedHostKey(hostPublicKey), []string{"rsa-sha2-256"}, ratelimiter.NoOpRateLimiter{}, logger) }) Describe("Connection Creation", func() { @@ -329,7 +330,7 @@ var _ = Describe("Connection", func() { echo "start" sleep 4 echo "end"`) - conn, connErr = ssh.NewConnectionWithServerAliveInterval(hostname, user, privateKey, gossh.FixedHostKey(hostPublicKey), []string{"rsa-sha2-256"}, 1, logger) + conn, connErr = ssh.NewConnectionWithServerAliveInterval(hostname, user, privateKey, gossh.FixedHostKey(hostPublicKey), []string{"rsa-sha2-256"}, 1, ratelimiter.NoOpRateLimiter{}, logger) Expect(connErr).NotTo(HaveOccurred()) stdOut, _, _, _ = conn.Run("/tmp/produce") @@ -360,6 +361,7 @@ var _ = Describe("Connection", func() { gossh.FixedHostKey(hostPublicKey), []string{"rsa-sha2-256"}, rapidKeepAliveSignalInterval, + ratelimiter.NoOpRateLimiter{}, logger) Expect(connErr).NotTo(HaveOccurred()) stdErr, _, runError = conn.Stream(command, stdout) diff --git a/ssh/fakes/fake_remote_runner_factory.go b/ssh/fakes/fake_remote_runner_factory.go index 2aeba4ec3..f8b31d31b 100644 --- a/ssh/fakes/fake_remote_runner_factory.go +++ b/ssh/fakes/fake_remote_runner_factory.go @@ -4,12 +4,13 @@ package fakes import ( "sync" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ssh" ssha "golang.org/x/crypto/ssh" ) type FakeRemoteRunnerFactory struct { - Stub func(string, string, string, ssha.HostKeyCallback, []string, ssh.Logger) (ssh.RemoteRunner, error) + Stub func(string, string, string, ssha.HostKeyCallback, []string, ratelimiter.RateLimiter, ssh.Logger) (ssh.RemoteRunner, error) mutex sync.RWMutex argsForCall []struct { arg1 string @@ -17,7 +18,8 @@ type FakeRemoteRunnerFactory struct { arg3 string arg4 ssha.HostKeyCallback arg5 []string - arg6 ssh.Logger + arg6 ratelimiter.RateLimiter + arg7 ssh.Logger } returns struct { result1 ssh.RemoteRunner @@ -31,7 +33,7 @@ type FakeRemoteRunnerFactory struct { invocationsMutex sync.RWMutex } -func (fake *FakeRemoteRunnerFactory) Spy(arg1 string, arg2 string, arg3 string, arg4 ssha.HostKeyCallback, arg5 []string, arg6 ssh.Logger) (ssh.RemoteRunner, error) { +func (fake *FakeRemoteRunnerFactory) Spy(arg1 string, arg2 string, arg3 string, arg4 ssha.HostKeyCallback, arg5 []string, arg6 ratelimiter.RateLimiter, arg7 ssh.Logger) (ssh.RemoteRunner, error) { var arg5Copy []string if arg5 != nil { arg5Copy = make([]string, len(arg5)) @@ -45,14 +47,15 @@ func (fake *FakeRemoteRunnerFactory) Spy(arg1 string, arg2 string, arg3 string, arg3 string arg4 ssha.HostKeyCallback arg5 []string - arg6 ssh.Logger - }{arg1, arg2, arg3, arg4, arg5Copy, arg6}) + arg6 ratelimiter.RateLimiter + arg7 ssh.Logger + }{arg1, arg2, arg3, arg4, arg5Copy, arg6, arg7}) stub := fake.Stub returns := fake.returns fake.recordInvocation("RemoteRunnerFactory", []interface{}{arg1, arg2, arg3, arg4, arg5Copy, arg6}) fake.mutex.Unlock() if stub != nil { - return stub(arg1, arg2, arg3, arg4, arg5, arg6) + return stub(arg1, arg2, arg3, arg4, arg5, arg6, arg7) } if specificReturn { return ret.result1, ret.result2 @@ -66,16 +69,16 @@ func (fake *FakeRemoteRunnerFactory) CallCount() int { return len(fake.argsForCall) } -func (fake *FakeRemoteRunnerFactory) Calls(stub func(string, string, string, ssha.HostKeyCallback, []string, ssh.Logger) (ssh.RemoteRunner, error)) { +func (fake *FakeRemoteRunnerFactory) Calls(stub func(string, string, string, ssha.HostKeyCallback, []string, ratelimiter.RateLimiter, ssh.Logger) (ssh.RemoteRunner, error)) { fake.mutex.Lock() defer fake.mutex.Unlock() fake.Stub = stub } -func (fake *FakeRemoteRunnerFactory) ArgsForCall(i int) (string, string, string, ssha.HostKeyCallback, []string, ssh.Logger) { +func (fake *FakeRemoteRunnerFactory) ArgsForCall(i int) (string, string, string, ssha.HostKeyCallback, []string, ratelimiter.RateLimiter, ssh.Logger) { fake.mutex.RLock() defer fake.mutex.RUnlock() - return fake.argsForCall[i].arg1, fake.argsForCall[i].arg2, fake.argsForCall[i].arg3, fake.argsForCall[i].arg4, fake.argsForCall[i].arg5, fake.argsForCall[i].arg6 + return fake.argsForCall[i].arg1, fake.argsForCall[i].arg2, fake.argsForCall[i].arg3, fake.argsForCall[i].arg4, fake.argsForCall[i].arg5, fake.argsForCall[i].arg6, fake.argsForCall[i].arg7 } func (fake *FakeRemoteRunnerFactory) Returns(result1 ssh.RemoteRunner, result2 error) { diff --git a/ssh/remote_runner.go b/ssh/remote_runner.go index 4828bbd5b..b4b8c4496 100644 --- a/ssh/remote_runner.go +++ b/ssh/remote_runner.go @@ -8,6 +8,7 @@ import ( "golang.org/x/crypto/ssh" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/pkg/errors" ) @@ -34,8 +35,8 @@ type SshRemoteRunner struct { connection SSHConnection } -func NewSshRemoteRunner(host, user, privateKey string, publicKeyCallback ssh.HostKeyCallback, publicKeyAlgorithm []string, logger Logger) (RemoteRunner, error) { - connection, err := NewConnection(host, user, privateKey, publicKeyCallback, publicKeyAlgorithm, logger) +func NewSshRemoteRunner(host, user, privateKey string, publicKeyCallback ssh.HostKeyCallback, publicKeyAlgorithm []string, rateLimiter ratelimiter.RateLimiter, logger Logger) (RemoteRunner, error) { + connection, err := NewConnection(host, user, privateKey, publicKeyCallback, publicKeyAlgorithm, rateLimiter, logger) if err != nil { return SshRemoteRunner{}, err } diff --git a/ssh/remote_runner_factory.go b/ssh/remote_runner_factory.go index 15df18cd3..a300c05a6 100644 --- a/ssh/remote_runner_factory.go +++ b/ssh/remote_runner_factory.go @@ -1,7 +1,10 @@ package ssh -import "golang.org/x/crypto/ssh" +import ( + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" + "golang.org/x/crypto/ssh" +) //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate //counterfeiter:generate -o fakes/fake_remote_runner_factory.go . RemoteRunnerFactory -type RemoteRunnerFactory func(host, user, privateKey string, publicKeyCallback ssh.HostKeyCallback, publicKeyAlgorithm []string, logger Logger) (RemoteRunner, error) +type RemoteRunnerFactory func(host, user, privateKey string, publicKeyCallback ssh.HostKeyCallback, publicKeyAlgorithm []string, rateLimiter ratelimiter.RateLimiter, logger Logger) (RemoteRunner, error) diff --git a/ssh/remote_runner_test.go b/ssh/remote_runner_test.go index 70d451fa2..2c1f744a9 100644 --- a/ssh/remote_runner_test.go +++ b/ssh/remote_runner_test.go @@ -7,6 +7,7 @@ import ( "os" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ssh" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/testcluster" boshlog "github.com/cloudfoundry/bosh-utils/logger" @@ -37,11 +38,11 @@ var _ = Describe("SshRemoteRunner", func() { logger := boshlog.New(boshlog.LevelDebug, combinedLog) sshConnection, err = ssh.NewConnection(testInstance.Address(), user, userPrivateKey, gossh.FixedHostKey(hostPublicKey), - []string{"rsa-sha2-256"}, logger) + []string{"rsa-sha2-256"}, ratelimiter.NoOpRateLimiter{}, logger) Expect(err).NotTo(HaveOccurred()) sshRemoteRunner, err = ssh.NewSshRemoteRunner(testInstance.Address(), user, userPrivateKey, gossh.FixedHostKey(hostPublicKey), - []string{"rsa-sha2-256"}, logger) + []string{"rsa-sha2-256"}, ratelimiter.NoOpRateLimiter{}, logger) Expect(err).NotTo(HaveOccurred()) }) diff --git a/standalone/deployment_manager.go b/standalone/deployment_manager.go index 8a75bdc13..e74ff770d 100644 --- a/standalone/deployment_manager.go +++ b/standalone/deployment_manager.go @@ -5,6 +5,7 @@ import ( "github.com/cloudfoundry-incubator/bosh-backup-and-restore/instance" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/orchestrator" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ssh" "github.com/pkg/errors" @@ -18,6 +19,7 @@ type DeploymentManager struct { privateKeyFile string jobFinder instance.JobFinder remoteRunnerFactory ssh.RemoteRunnerFactory + rateLimiter ratelimiter.RateLimiter } func NewDeploymentManager( @@ -25,6 +27,7 @@ func NewDeploymentManager( hostName, username, privateKey string, jobFinder instance.JobFinder, remoteRunnerFactory ssh.RemoteRunnerFactory, + rateLimiter ratelimiter.RateLimiter, ) DeploymentManager { return DeploymentManager{ Logger: logger, @@ -33,6 +36,7 @@ func NewDeploymentManager( privateKeyFile: privateKey, jobFinder: jobFinder, remoteRunnerFactory: remoteRunnerFactory, + rateLimiter: rateLimiter, } } @@ -42,7 +46,7 @@ func (dm DeploymentManager) Find(deploymentName string) (orchestrator.Deployment return nil, errors.Wrap(err, "failed reading private key") } - remoteRunner, err := dm.remoteRunnerFactory(dm.hostName, dm.username, string(keyContents), gossh.InsecureIgnoreHostKey(), nil, dm.Logger) + remoteRunner, err := dm.remoteRunnerFactory(dm.hostName, dm.username, string(keyContents), gossh.InsecureIgnoreHostKey(), nil, dm.rateLimiter, dm.Logger) if err != nil { return nil, err } diff --git a/standalone/deployment_manager_test.go b/standalone/deployment_manager_test.go index 7cb55ab96..b0292e3df 100644 --- a/standalone/deployment_manager_test.go +++ b/standalone/deployment_manager_test.go @@ -5,6 +5,7 @@ import ( "os" "github.com/cloudfoundry-incubator/bosh-backup-and-restore/instance" + "github.com/cloudfoundry-incubator/bosh-backup-and-restore/ratelimiter" . "github.com/cloudfoundry-incubator/bosh-backup-and-restore/standalone" instancefakes "github.com/cloudfoundry-incubator/bosh-backup-and-restore/instance/fakes" @@ -35,7 +36,7 @@ var _ = Describe("DeploymentManager", func() { fakeJobFinder = new(instancefakes.FakeJobFinder) remoteRunner = new(sshfakes.FakeRemoteRunner) - deploymentManager = NewDeploymentManager(logger, hostName, username, privateKey, fakeJobFinder, remoteRunnerFactory.Spy) + deploymentManager = NewDeploymentManager(logger, hostName, username, privateKey, fakeJobFinder, remoteRunnerFactory.Spy, &ratelimiter.ConnectionRateLimiter{}) }) AfterEach(func() {