diff --git a/internal/provider/provider.go b/internal/provider/provider.go index e80c43c..b34cbcb 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -39,7 +39,14 @@ func New(version string) func() *schema.Provider { ResourcesMap: map[string]*schema.Resource{ "remotefile": resourceRemotefile(), }, - Schema: map[string]*schema.Schema{}, + Schema: map[string]*schema.Schema{ + "max_sessions": { + Type: schema.TypeInt, + Optional: true, + Default: 3, + Description: "Maximum number of open sessions in each host connection.", + }, + }, } p.ConfigureContextFunc = configure(version, p) @@ -49,15 +56,19 @@ func New(version string) func() *schema.Provider { } type apiClient struct { - mux *sync.Mutex - remoteClients map[string]*RemoteClient + mux *sync.Mutex + remoteClients map[string]*RemoteClient + activeSessions map[string]int + maxSessions int } func configure(version string, p *schema.Provider) func(context.Context, *schema.ResourceData) (interface{}, diag.Diagnostics) { return func(c context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) { client := apiClient{ - mux: &sync.Mutex{}, - remoteClients: map[string]*RemoteClient{}, + maxSessions: d.Get("max_sessions").(int), + mux: &sync.Mutex{}, + remoteClients: map[string]*RemoteClient{}, + activeSessions: map[string]int{}, } return &client, diag.Diagnostics{} @@ -66,21 +77,45 @@ func configure(version string, p *schema.Provider) func(context.Context, *schema func (c *apiClient) getRemoteClient(d *schema.ResourceData) (*RemoteClient, error) { connectionID := resourceConnectionHash(d) - c.mux.Lock() defer c.mux.Unlock() + for { + c.mux.Lock() + + client, ok := c.remoteClients[connectionID] + if ok { + if c.activeSessions[connectionID] >= c.maxSessions { + c.mux.Unlock() + continue + } + c.activeSessions[connectionID] += 1 + + return client, nil + } - client, ok := c.remoteClients[connectionID] - if ok { + client, err := RemoteClientFromResource(d) + if err != nil { + return nil, err + } + + c.remoteClients[connectionID] = client + c.activeSessions[connectionID] = 1 return client, nil } +} + +func (c *apiClient) closeRemoteClient(d *schema.ResourceData) error { + connectionID := resourceConnectionHash(d) + c.mux.Lock() + defer c.mux.Unlock() - client, err := RemoteClientFromResource(d) - if err != nil { - return nil, err + c.activeSessions[connectionID] -= 1 + if c.activeSessions[connectionID] == 0 { + client := c.remoteClients[connectionID] + delete(c.remoteClients, connectionID) + return client.Close() } - c.remoteClients[connectionID] = client - return client, nil + return nil } func RemoteClientFromResource(d *schema.ResourceData) (*RemoteClient, error) { diff --git a/internal/provider/remote_client.go b/internal/provider/remote_client.go index 684baef..754b630 100644 --- a/internal/provider/remote_client.go +++ b/internal/provider/remote_client.go @@ -166,6 +166,10 @@ func NewRemoteClient(host string, clientConfig ssh.ClientConfig) (*RemoteClient, }, nil } +func (c *RemoteClient) Close() error { + return c.sshClient.Close() +} + func (c *RemoteClient) GetSSHClient() *ssh.Client { return c.sshClient } diff --git a/internal/provider/resource_remotefile.go b/internal/provider/resource_remotefile.go index c92beef..ffc64be 100644 --- a/internal/provider/resource_remotefile.go +++ b/internal/provider/resource_remotefile.go @@ -100,7 +100,7 @@ func resourceRemotefileCreate(ctx context.Context, d *schema.ResourceData, meta client, err := meta.(*apiClient).getRemoteClient(d) if err != nil { - return diag.Errorf(err.Error()) + return diag.Errorf("error while opening remote client: %s", err.Error()) } sudo, ok := d.GetOk("conn.0.sudo") @@ -120,6 +120,11 @@ func resourceRemotefileCreate(ctx context.Context, d *schema.ResourceData, meta } } + err = meta.(*apiClient).closeRemoteClient(d) + if err != nil { + return diag.Errorf("error while closing remote client: %s", err.Error()) + } + return diag.Diagnostics{} } @@ -128,7 +133,7 @@ func resourceRemotefileRead(ctx context.Context, d *schema.ResourceData, meta in client, err := meta.(*apiClient).getRemoteClient(d) if err != nil { - return diag.Errorf(err.Error()) + return diag.Errorf("error while opening remote client: %s", err.Error()) } sudo, ok := d.GetOk("conn.0.sudo") @@ -152,6 +157,11 @@ func resourceRemotefileRead(ctx context.Context, d *schema.ResourceData, meta in } } + err = meta.(*apiClient).closeRemoteClient(d) + if err != nil { + return diag.Errorf("error while closing remote client: %s", err.Error()) + } + return diag.Diagnostics{} } @@ -162,7 +172,7 @@ func resourceRemotefileUpdate(ctx context.Context, d *schema.ResourceData, meta func resourceRemotefileDelete(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics { client, err := meta.(*apiClient).getRemoteClient(d) if err != nil { - return diag.Errorf(err.Error()) + return diag.Errorf("error while opening remote client: %s", err.Error()) } sudo, ok := d.GetOk("conn.0.sudo") @@ -186,5 +196,10 @@ func resourceRemotefileDelete(ctx context.Context, d *schema.ResourceData, meta } } + err = meta.(*apiClient).closeRemoteClient(d) + if err != nil { + return diag.Errorf("error while closing remote client: %s", err.Error()) + } + return diag.Diagnostics{} }