Skip to content

Commit

Permalink
Avoid multiple connections
Browse files Browse the repository at this point in the history
  • Loading branch information
tenstad committed May 6, 2021
1 parent caec48b commit f072bbb
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 209 deletions.
9 changes: 5 additions & 4 deletions docs/data-sources/data_source.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ description: |-
```terraform
data "remotefile" "bar" {
conn {
host = "foo.com"
port = "22"
username = "foo"
private_key = "<ssh private key>"
host = "foo.com"
port = 22
username = "foo"
password = "<password>"
}
path = "/tmp/bar.txt"
}
Expand All @@ -27,6 +27,7 @@ data "remotefile" "bar" {

### Required

- **conn** (Object, Required) Connection to remote host.
- **path** (String, Required) Path to file on remote host.

### Optional
Expand Down
4 changes: 3 additions & 1 deletion docs/resources/resource.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ File on remote host.
resource "remotefile" "foo" {
conn {
host = "foo.com"
port = "22"
port = 22
username = "foo"
sudo = true
private_key = "<ssh private key>"
}
path = "/tmp/foo.txt"
Expand All @@ -30,6 +31,7 @@ resource "remotefile" "foo" {

### Required

- **conn** (Object, Required) Connection to remote host.
- **path** (String, Required) Path to file on remote host.
- **content** (String, Required) Content of file on remote host.

Expand Down
238 changes: 46 additions & 192 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
package provider

import (
"bytes"
"context"
"fmt"
"io/ioutil"
"strconv"
"strings"
"sync"

"github.com/bramvdbogaerde/go-scp"
"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)

Expand Down Expand Up @@ -49,17 +48,41 @@ func New(version string) func() *schema.Provider {
}

type apiClient struct {
clientConfig ssh.ClientConfig
host string
mux *sync.Mutex
remoteClients map[string]*RemoteClient
}

func configure(version string, p *schema.Provider) func(context.Context, *schema.ResourceData) (interface{}, diag.Diagnostics) {
return func(c context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) {
return &apiClient{}, diag.Diagnostics{}
client := apiClient{
mux: &sync.Mutex{},
remoteClients: map[string]*RemoteClient{},
}

return &client, diag.Diagnostics{}
}
}

func newClient(d *schema.ResourceData) (*apiClient, error) {
func (c *apiClient) getRemoteClient(d *schema.ResourceData) (*RemoteClient, error) {
connectionID := resourceConnectionHash(d)
c.mux.Lock()
defer c.mux.Unlock()

client, ok := c.remoteClients[connectionID]
if ok {
return client, nil
}

client, err := RemoteClientFromResource(d)
if err != nil {
return nil, err
}

c.remoteClients[connectionID] = client
return client, nil
}

func RemoteClientFromResource(d *schema.ResourceData) (*RemoteClient, error) {
clientConfig := ssh.ClientConfig{
User: d.Get("conn.0.username").(string),
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Expand Down Expand Up @@ -92,195 +115,26 @@ func newClient(d *schema.ResourceData) (*apiClient, error) {
clientConfig.Auth = append(clientConfig.Auth, ssh.PublicKeys(signer))
}

client := apiClient{
clientConfig: clientConfig,
host: fmt.Sprintf("%s:%d", d.Get("conn.0.host").(string), d.Get("conn.0.port").(int)),
}

return &client, nil
}

func (c *apiClient) writeFile(d *schema.ResourceData) error {
scpClient, err := c.getSCPClient()
if err != nil {
return err
}
defer scpClient.Close()

return scpClient.CopyFile(strings.NewReader(d.Get("content").(string)), d.Get("path").(string), d.Get("permissions").(string))
}

func (c *apiClient) writeFileSudo(d *schema.ResourceData) error {
sshClient, err := c.getSSHClient()
if err != nil {
return err
}

session, err := sshClient.NewSession()
if err != nil {
return err
}
defer session.Close()

stdin, err := session.StdinPipe()
if err != nil {
return err
}

content := d.Get("content").(string)
go func() {
stdin.Write([]byte(content))
stdin.Close()
}()

cmd := fmt.Sprintf("cat /dev/stdin | sudo tee %s", d.Get("path").(string))
return session.Run(cmd)
}

func (c *apiClient) chmodFileSudo(d *schema.ResourceData) error {
sshClient, err := c.getSSHClient()
if err != nil {
return err
}

session, err := sshClient.NewSession()
if err != nil {
return err
}
defer session.Close()

cmd := fmt.Sprintf("sudo chmod %s %s", d.Get("permissions").(string), d.Get("path").(string))
return session.Run(cmd)
}

func (c *apiClient) readFile(d *schema.ResourceData) error {
sftpClient, err := c.getSFTPClient()
if err != nil {
return err
}
defer sftpClient.Close()

file, err := sftpClient.Open(d.Get("path").(string))
if err != nil {
return err
}
defer file.Close()

content := bytes.Buffer{}
_, err = file.WriteTo(&content)

if err != nil {
return err
}

d.Set("content", string(content.String()))
return nil
host := fmt.Sprintf("%s:%d", d.Get("conn.0.host").(string), d.Get("conn.0.port").(int))
return NewRemoteClient(host, clientConfig)
}

func (c *apiClient) fileExistsSudo(d *schema.ResourceData) (bool, error) {
sshClient, err := c.getSSHClient()
if err != nil {
return false, err
}

session, err := sshClient.NewSession()
if err != nil {
return false, err
}
defer session.Close()

path := d.Get("path").(string)
cmd := fmt.Sprintf("test -f %s", path)
err = session.Run(cmd)

if err != nil {
session2, err := sshClient.NewSession()
if err != nil {
return false, err
}
defer session2.Close()

cmd := fmt.Sprintf("test ! -f %s", path)
return false, session2.Run(cmd)
}

return true, nil
}

func (c *apiClient) readFileSudo(d *schema.ResourceData) error {
sshClient, err := c.getSSHClient()
if err != nil {
return err
func resourceConnectionHash(d *schema.ResourceData) string {
elements := []string{
d.Get("conn.0.host").(string),
d.Get("conn.0.username").(string),
strconv.Itoa(d.Get("conn.0.port").(int)),
resourceStringWithDefault(d, "conn.0.password", ""),
resourceStringWithDefault(d, "conn.0.private_key", ""),
resourceStringWithDefault(d, "conn.0.private_key_path", ""),
}

session, err := sshClient.NewSession()
if err != nil {
return err
}
defer session.Close()

cmd := fmt.Sprintf("sudo cat %s", d.Get("path").(string))
content, err := session.Output(cmd)
if err != nil {
return err
}

d.Set("content", string(content))
return nil
return strings.Join(elements, "::")
}

func (c *apiClient) deleteFile(d *schema.ResourceData) error {
sftpClient, err := c.getSFTPClient()
if err != nil {
return err
}
defer sftpClient.Close()

return sftpClient.Remove(d.Get("path").(string))
}

func (c *apiClient) deleteFileSudo(d *schema.ResourceData) error {
sshClient, err := c.getSSHClient()
if err != nil {
return err
}

session, err := sshClient.NewSession()
if err != nil {
return err
}
defer session.Close()

cmd := fmt.Sprintf("sudo cat %s", d.Get("path").(string))
return session.Run(cmd)
}

func (c apiClient) getSSHClient() (*ssh.Client, error) {
sshClient, err := ssh.Dial("tcp", c.host, &c.clientConfig)
if err != nil {
return nil, fmt.Errorf("couldn't establish a connection to the remote server: %s", err.Error())
}
return sshClient, nil
}

func (c apiClient) getSCPClient() (*scp.Client, error) {
scpClient := scp.NewClient(c.host, &c.clientConfig)
err := scpClient.Connect()
if err != nil {
return nil, fmt.Errorf("couldn't establish a connection to the remote server: %s", err.Error())
}
return &scpClient, nil
}

func (c apiClient) getSFTPClient() (*sftp.Client, error) {
sshClient, err := c.getSSHClient()
if err != nil {
return nil, err
}

sftp, err := sftp.NewClient(sshClient)
if err != nil {
return nil, err
func resourceStringWithDefault(d *schema.ResourceData, key string, defaultValue string) string {
str, ok := d.GetOk(key)
if ok {
return str.(string)
}
return sftp, nil
return defaultValue
}
Loading

0 comments on commit f072bbb

Please sign in to comment.