From f072bbb5b0777cd2feecb9b73c2c09d97629ba1d Mon Sep 17 00:00:00 2001 From: Amund Tenstad Date: Thu, 6 May 2021 09:24:42 +0200 Subject: [PATCH] Avoid multiple connections --- docs/data-sources/data_source.md | 9 +- docs/resources/resource.md | 4 +- internal/provider/provider.go | 238 +++++------------------ internal/provider/remote_client.go | 179 +++++++++++++++++ internal/provider/resource_remotefile.go | 24 +-- 5 files changed, 245 insertions(+), 209 deletions(-) create mode 100644 internal/provider/remote_client.go diff --git a/docs/data-sources/data_source.md b/docs/data-sources/data_source.md index d40f027..ab4dd5a 100644 --- a/docs/data-sources/data_source.md +++ b/docs/data-sources/data_source.md @@ -14,10 +14,10 @@ description: |- ```terraform data "remotefile" "bar" { conn { - host = "foo.com" - port = "22" - username = "foo" - private_key = "" + host = "foo.com" + port = 22 + username = "foo" + password = "" } path = "/tmp/bar.txt" } @@ -27,6 +27,7 @@ data "remotefile" "bar" { ### Required +- **conn** (Object, Required) Connection to remote host. - **path** (String, Required) Path to file on remote host. ### Optional diff --git a/docs/resources/resource.md b/docs/resources/resource.md index 02269b5..324fb41 100644 --- a/docs/resources/resource.md +++ b/docs/resources/resource.md @@ -15,8 +15,9 @@ File on remote host. resource "remotefile" "foo" { conn { host = "foo.com" - port = "22" + port = 22 username = "foo" + sudo = true private_key = "" } path = "/tmp/foo.txt" @@ -30,6 +31,7 @@ resource "remotefile" "foo" { ### Required +- **conn** (Object, Required) Connection to remote host. - **path** (String, Required) Path to file on remote host. - **content** (String, Required) Content of file on remote host. diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 82f12b6..8bc78ff 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -1,16 +1,15 @@ package provider import ( - "bytes" "context" "fmt" "io/ioutil" + "strconv" "strings" + "sync" - "github.com/bramvdbogaerde/go-scp" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" - "github.com/pkg/sftp" "golang.org/x/crypto/ssh" ) @@ -49,17 +48,41 @@ func New(version string) func() *schema.Provider { } type apiClient struct { - clientConfig ssh.ClientConfig - host string + mux *sync.Mutex + remoteClients map[string]*RemoteClient } func configure(version string, p *schema.Provider) func(context.Context, *schema.ResourceData) (interface{}, diag.Diagnostics) { return func(c context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) { - return &apiClient{}, diag.Diagnostics{} + client := apiClient{ + mux: &sync.Mutex{}, + remoteClients: map[string]*RemoteClient{}, + } + + return &client, diag.Diagnostics{} } } -func newClient(d *schema.ResourceData) (*apiClient, error) { +func (c *apiClient) getRemoteClient(d *schema.ResourceData) (*RemoteClient, error) { + connectionID := resourceConnectionHash(d) + c.mux.Lock() + defer c.mux.Unlock() + + client, ok := c.remoteClients[connectionID] + if ok { + return client, nil + } + + client, err := RemoteClientFromResource(d) + if err != nil { + return nil, err + } + + c.remoteClients[connectionID] = client + return client, nil +} + +func RemoteClientFromResource(d *schema.ResourceData) (*RemoteClient, error) { clientConfig := ssh.ClientConfig{ User: d.Get("conn.0.username").(string), HostKeyCallback: ssh.InsecureIgnoreHostKey(), @@ -92,195 +115,26 @@ func newClient(d *schema.ResourceData) (*apiClient, error) { clientConfig.Auth = append(clientConfig.Auth, ssh.PublicKeys(signer)) } - client := apiClient{ - clientConfig: clientConfig, - host: fmt.Sprintf("%s:%d", d.Get("conn.0.host").(string), d.Get("conn.0.port").(int)), - } - - return &client, nil -} - -func (c *apiClient) writeFile(d *schema.ResourceData) error { - scpClient, err := c.getSCPClient() - if err != nil { - return err - } - defer scpClient.Close() - - return scpClient.CopyFile(strings.NewReader(d.Get("content").(string)), d.Get("path").(string), d.Get("permissions").(string)) -} - -func (c *apiClient) writeFileSudo(d *schema.ResourceData) error { - sshClient, err := c.getSSHClient() - if err != nil { - return err - } - - session, err := sshClient.NewSession() - if err != nil { - return err - } - defer session.Close() - - stdin, err := session.StdinPipe() - if err != nil { - return err - } - - content := d.Get("content").(string) - go func() { - stdin.Write([]byte(content)) - stdin.Close() - }() - - cmd := fmt.Sprintf("cat /dev/stdin | sudo tee %s", d.Get("path").(string)) - return session.Run(cmd) -} - -func (c *apiClient) chmodFileSudo(d *schema.ResourceData) error { - sshClient, err := c.getSSHClient() - if err != nil { - return err - } - - session, err := sshClient.NewSession() - if err != nil { - return err - } - defer session.Close() - - cmd := fmt.Sprintf("sudo chmod %s %s", d.Get("permissions").(string), d.Get("path").(string)) - return session.Run(cmd) -} - -func (c *apiClient) readFile(d *schema.ResourceData) error { - sftpClient, err := c.getSFTPClient() - if err != nil { - return err - } - defer sftpClient.Close() - - file, err := sftpClient.Open(d.Get("path").(string)) - if err != nil { - return err - } - defer file.Close() - - content := bytes.Buffer{} - _, err = file.WriteTo(&content) - - if err != nil { - return err - } - - d.Set("content", string(content.String())) - return nil + host := fmt.Sprintf("%s:%d", d.Get("conn.0.host").(string), d.Get("conn.0.port").(int)) + return NewRemoteClient(host, clientConfig) } -func (c *apiClient) fileExistsSudo(d *schema.ResourceData) (bool, error) { - sshClient, err := c.getSSHClient() - if err != nil { - return false, err - } - - session, err := sshClient.NewSession() - if err != nil { - return false, err - } - defer session.Close() - - path := d.Get("path").(string) - cmd := fmt.Sprintf("test -f %s", path) - err = session.Run(cmd) - - if err != nil { - session2, err := sshClient.NewSession() - if err != nil { - return false, err - } - defer session2.Close() - - cmd := fmt.Sprintf("test ! -f %s", path) - return false, session2.Run(cmd) - } - - return true, nil -} - -func (c *apiClient) readFileSudo(d *schema.ResourceData) error { - sshClient, err := c.getSSHClient() - if err != nil { - return err +func resourceConnectionHash(d *schema.ResourceData) string { + elements := []string{ + d.Get("conn.0.host").(string), + d.Get("conn.0.username").(string), + strconv.Itoa(d.Get("conn.0.port").(int)), + resourceStringWithDefault(d, "conn.0.password", ""), + resourceStringWithDefault(d, "conn.0.private_key", ""), + resourceStringWithDefault(d, "conn.0.private_key_path", ""), } - - session, err := sshClient.NewSession() - if err != nil { - return err - } - defer session.Close() - - cmd := fmt.Sprintf("sudo cat %s", d.Get("path").(string)) - content, err := session.Output(cmd) - if err != nil { - return err - } - - d.Set("content", string(content)) - return nil + return strings.Join(elements, "::") } -func (c *apiClient) deleteFile(d *schema.ResourceData) error { - sftpClient, err := c.getSFTPClient() - if err != nil { - return err - } - defer sftpClient.Close() - - return sftpClient.Remove(d.Get("path").(string)) -} - -func (c *apiClient) deleteFileSudo(d *schema.ResourceData) error { - sshClient, err := c.getSSHClient() - if err != nil { - return err - } - - session, err := sshClient.NewSession() - if err != nil { - return err - } - defer session.Close() - - cmd := fmt.Sprintf("sudo cat %s", d.Get("path").(string)) - return session.Run(cmd) -} - -func (c apiClient) getSSHClient() (*ssh.Client, error) { - sshClient, err := ssh.Dial("tcp", c.host, &c.clientConfig) - if err != nil { - return nil, fmt.Errorf("couldn't establish a connection to the remote server: %s", err.Error()) - } - return sshClient, nil -} - -func (c apiClient) getSCPClient() (*scp.Client, error) { - scpClient := scp.NewClient(c.host, &c.clientConfig) - err := scpClient.Connect() - if err != nil { - return nil, fmt.Errorf("couldn't establish a connection to the remote server: %s", err.Error()) - } - return &scpClient, nil -} - -func (c apiClient) getSFTPClient() (*sftp.Client, error) { - sshClient, err := c.getSSHClient() - if err != nil { - return nil, err - } - - sftp, err := sftp.NewClient(sshClient) - if err != nil { - return nil, err +func resourceStringWithDefault(d *schema.ResourceData, key string, defaultValue string) string { + str, ok := d.GetOk(key) + if ok { + return str.(string) } - return sftp, nil + return defaultValue } diff --git a/internal/provider/remote_client.go b/internal/provider/remote_client.go new file mode 100644 index 0000000..82bc2c5 --- /dev/null +++ b/internal/provider/remote_client.go @@ -0,0 +1,179 @@ +package provider + +import ( + "bytes" + "fmt" + "strings" + + "github.com/bramvdbogaerde/go-scp" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +type RemoteClient struct { + sshClient *ssh.Client +} + +func (c *RemoteClient) WriteFile(d *schema.ResourceData) error { + scpClient, err := c.GetSCPClient() + if err != nil { + return err + } + defer scpClient.Close() + + return scpClient.CopyFile(strings.NewReader(d.Get("content").(string)), d.Get("path").(string), d.Get("permissions").(string)) +} + +func (c *RemoteClient) WriteFileSudo(d *schema.ResourceData) error { + sshClient := c.GetSSHClient() + + session, err := sshClient.NewSession() + if err != nil { + return err + } + defer session.Close() + + stdin, err := session.StdinPipe() + if err != nil { + return err + } + + content := d.Get("content").(string) + go func() { + stdin.Write([]byte(content)) + stdin.Close() + }() + + cmd := fmt.Sprintf("cat /dev/stdin | sudo tee %s", d.Get("path").(string)) + return session.Run(cmd) +} + +func (c *RemoteClient) ChmodFileSudo(d *schema.ResourceData) error { + sshClient := c.GetSSHClient() + + session, err := sshClient.NewSession() + if err != nil { + return err + } + defer session.Close() + + cmd := fmt.Sprintf("sudo chmod %s %s", d.Get("permissions").(string), d.Get("path").(string)) + return session.Run(cmd) +} + +func (c *RemoteClient) FileExistsSudo(d *schema.ResourceData) (bool, error) { + sshClient := c.GetSSHClient() + + session, err := sshClient.NewSession() + if err != nil { + return false, err + } + defer session.Close() + + path := d.Get("path").(string) + cmd := fmt.Sprintf("test -f %s", path) + err = session.Run(cmd) + + if err != nil { + session2, err := sshClient.NewSession() + if err != nil { + return false, err + } + defer session2.Close() + + cmd := fmt.Sprintf("test ! -f %s", path) + return false, session2.Run(cmd) + } + + return true, nil +} + +func (c *RemoteClient) ReadFile(d *schema.ResourceData) error { + sftpClient, err := c.GetSFTPClient() + if err != nil { + return err + } + defer sftpClient.Close() + + file, err := sftpClient.Open(d.Get("path").(string)) + if err != nil { + return err + } + defer file.Close() + + content := bytes.Buffer{} + _, err = file.WriteTo(&content) + + if err != nil { + return err + } + + d.Set("content", string(content.String())) + return nil +} + +func (c *RemoteClient) ReadFileSudo(d *schema.ResourceData) error { + sshClient := c.GetSSHClient() + + session, err := sshClient.NewSession() + if err != nil { + return err + } + defer session.Close() + + cmd := fmt.Sprintf("sudo cat %s", d.Get("path").(string)) + content, err := session.Output(cmd) + if err != nil { + return err + } + + d.Set("content", string(content)) + return nil +} + +func (c *RemoteClient) DeleteFile(d *schema.ResourceData) error { + sftpClient, err := c.GetSFTPClient() + if err != nil { + return err + } + defer sftpClient.Close() + + return sftpClient.Remove(d.Get("path").(string)) +} + +func (c *RemoteClient) DeleteFileSudo(d *schema.ResourceData) error { + sshClient := c.GetSSHClient() + + session, err := sshClient.NewSession() + if err != nil { + return err + } + defer session.Close() + + cmd := fmt.Sprintf("sudo cat %s", d.Get("path").(string)) + return session.Run(cmd) +} + +func NewRemoteClient(host string, clientConfig ssh.ClientConfig) (*RemoteClient, error) { + client, err := ssh.Dial("tcp", host, &clientConfig) + if err != nil { + return nil, fmt.Errorf("couldn't establish a connection to the remote server: %s", err.Error()) + } + + return &RemoteClient{ + sshClient: client, + }, nil +} + +func (c *RemoteClient) GetSSHClient() *ssh.Client { + return c.sshClient +} + +func (c *RemoteClient) GetSCPClient() (scp.Client, error) { + return scp.NewClientBySSH(c.sshClient) +} + +func (c *RemoteClient) GetSFTPClient() (*sftp.Client, error) { + return sftp.NewClient(c.sshClient) +} diff --git a/internal/provider/resource_remotefile.go b/internal/provider/resource_remotefile.go index 07d7dd7..6e69c60 100644 --- a/internal/provider/resource_remotefile.go +++ b/internal/provider/resource_remotefile.go @@ -93,23 +93,23 @@ func resourceRemotefile() *schema.Resource { func resourceRemotefileCreate(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { d.SetId(fmt.Sprintf("%s:%s", d.Get("conn.0.host").(string), d.Get("path").(string))) - client, err := newClient(d) + client, err := meta.(*apiClient).getRemoteClient(d) if err != nil { return diag.Errorf(err.Error()) } sudo, ok := d.GetOk("conn.0.sudo") if ok && sudo.(bool) { - err := client.writeFileSudo(d) + err := client.WriteFileSudo(d) if err != nil { return diag.Errorf("error while creating remote file with sudo: %s", err.Error()) } - err = client.chmodFileSudo(d) + err = client.ChmodFileSudo(d) if err != nil { return diag.Errorf("error while changing permissions of remote file with sudo: %s", err.Error()) } } else { - err := client.writeFile(d) + err := client.WriteFile(d) if err != nil { return diag.Errorf("error while creating remote file: %s", err.Error()) } @@ -121,19 +121,19 @@ func resourceRemotefileCreate(ctx context.Context, d *schema.ResourceData, meta func resourceRemotefileRead(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { d.SetId(fmt.Sprintf("%s:%s", d.Get("conn.0.host").(string), d.Get("path").(string))) - client, err := newClient(d) + client, err := meta.(*apiClient).getRemoteClient(d) if err != nil { return diag.Errorf(err.Error()) } sudo, ok := d.GetOk("conn.0.sudo") if ok && sudo.(bool) { - exists, err := client.fileExistsSudo(d) + exists, err := client.FileExistsSudo(d) if err != nil { return diag.Errorf("error while checking if remote file exists with sudo: %s", err.Error()) } if exists { - err := client.readFileSudo(d) + err := client.ReadFileSudo(d) if err != nil { return diag.Errorf("error while reading remote file with sudo: %s", err.Error()) } @@ -141,7 +141,7 @@ func resourceRemotefileRead(ctx context.Context, d *schema.ResourceData, meta in return diag.Errorf("cannot read file, it does not exist.") } } else { - err := client.readFile(d) + err := client.ReadFile(d) if err != nil { return diag.Errorf("error while reading remote file: %s", err.Error()) } @@ -155,19 +155,19 @@ func resourceRemotefileUpdate(ctx context.Context, d *schema.ResourceData, meta } func resourceRemotefileDelete(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { - client, err := newClient(d) + client, err := meta.(*apiClient).getRemoteClient(d) if err != nil { return diag.Errorf(err.Error()) } sudo, ok := d.GetOk("conn.0.sudo") if ok && sudo.(bool) { - exists, err := client.fileExistsSudo(d) + exists, err := client.FileExistsSudo(d) if err != nil { return diag.Errorf("error while checking if remote file exists with sudo: %s", err.Error()) } if exists { - err := client.deleteFileSudo(d) + err := client.DeleteFileSudo(d) if err != nil { return diag.Errorf("error while removing remote file with sudo: %s", err.Error()) } @@ -175,7 +175,7 @@ func resourceRemotefileDelete(ctx context.Context, d *schema.ResourceData, meta return diag.Errorf("cannot delete file, it does not exist.") } } else { - err := client.deleteFile(d) + err := client.DeleteFile(d) if err != nil { return diag.Errorf("error while removing remote file: %s", err.Error()) }