Skip to content

Commit

Permalink
Set max sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
tenstad committed Jun 15, 2021
1 parent f9a0135 commit 95f6c52
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 16 deletions.
61 changes: 48 additions & 13 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,14 @@ func New(version string) func() *schema.Provider {
ResourcesMap: map[string]*schema.Resource{
"remotefile": resourceRemotefile(),
},
Schema: map[string]*schema.Schema{},
Schema: map[string]*schema.Schema{
"max_sessions": {
Type: schema.TypeInt,
Optional: true,
Default: 3,
Description: "Maximum number of open sessions in each host connection.",
},
},
}

p.ConfigureContextFunc = configure(version, p)
Expand All @@ -49,15 +56,19 @@ func New(version string) func() *schema.Provider {
}

type apiClient struct {
mux *sync.Mutex
remoteClients map[string]*RemoteClient
mux *sync.Mutex
remoteClients map[string]*RemoteClient
activeSessions map[string]int
maxSessions int
}

func configure(version string, p *schema.Provider) func(context.Context, *schema.ResourceData) (interface{}, diag.Diagnostics) {
return func(c context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) {
client := apiClient{
mux: &sync.Mutex{},
remoteClients: map[string]*RemoteClient{},
maxSessions: d.Get("max_sessions").(int),
mux: &sync.Mutex{},
remoteClients: map[string]*RemoteClient{},
activeSessions: map[string]int{},
}

return &client, diag.Diagnostics{}
Expand All @@ -66,21 +77,45 @@ func configure(version string, p *schema.Provider) func(context.Context, *schema

func (c *apiClient) getRemoteClient(d *schema.ResourceData) (*RemoteClient, error) {
connectionID := resourceConnectionHash(d)
c.mux.Lock()
defer c.mux.Unlock()
for {
c.mux.Lock()

client, ok := c.remoteClients[connectionID]
if ok {
if c.activeSessions[connectionID] >= c.maxSessions {
c.mux.Unlock()
continue
}
c.activeSessions[connectionID] += 1

return client, nil
}

client, ok := c.remoteClients[connectionID]
if ok {
client, err := RemoteClientFromResource(d)
if err != nil {
return nil, err
}

c.remoteClients[connectionID] = client
c.activeSessions[connectionID] = 1
return client, nil
}
}

func (c *apiClient) closeRemoteClient(d *schema.ResourceData) error {
connectionID := resourceConnectionHash(d)
c.mux.Lock()
defer c.mux.Unlock()

client, err := RemoteClientFromResource(d)
if err != nil {
return nil, err
c.activeSessions[connectionID] -= 1
if c.activeSessions[connectionID] == 0 {
client := c.remoteClients[connectionID]
delete(c.remoteClients, connectionID)
return client.Close()
}

c.remoteClients[connectionID] = client
return client, nil
return nil
}

func RemoteClientFromResource(d *schema.ResourceData) (*RemoteClient, error) {
Expand Down
4 changes: 4 additions & 0 deletions internal/provider/remote_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ func NewRemoteClient(host string, clientConfig ssh.ClientConfig) (*RemoteClient,
}, nil
}

func (c *RemoteClient) Close() error {
return c.sshClient.Close()
}

func (c *RemoteClient) GetSSHClient() *ssh.Client {
return c.sshClient
}
Expand Down
21 changes: 18 additions & 3 deletions internal/provider/resource_remotefile.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func resourceRemotefileCreate(ctx context.Context, d *schema.ResourceData, meta

client, err := meta.(*apiClient).getRemoteClient(d)
if err != nil {
return diag.Errorf(err.Error())
return diag.Errorf("error while opening remote client: %s", err.Error())
}

sudo, ok := d.GetOk("conn.0.sudo")
Expand All @@ -120,6 +120,11 @@ func resourceRemotefileCreate(ctx context.Context, d *schema.ResourceData, meta
}
}

err = meta.(*apiClient).closeRemoteClient(d)
if err != nil {
return diag.Errorf("error while closing remote client: %s", err.Error())
}

return diag.Diagnostics{}
}

Expand All @@ -128,7 +133,7 @@ func resourceRemotefileRead(ctx context.Context, d *schema.ResourceData, meta in

client, err := meta.(*apiClient).getRemoteClient(d)
if err != nil {
return diag.Errorf(err.Error())
return diag.Errorf("error while opening remote client: %s", err.Error())
}

sudo, ok := d.GetOk("conn.0.sudo")
Expand All @@ -152,6 +157,11 @@ func resourceRemotefileRead(ctx context.Context, d *schema.ResourceData, meta in
}
}

err = meta.(*apiClient).closeRemoteClient(d)
if err != nil {
return diag.Errorf("error while closing remote client: %s", err.Error())
}

return diag.Diagnostics{}
}

Expand All @@ -162,7 +172,7 @@ func resourceRemotefileUpdate(ctx context.Context, d *schema.ResourceData, meta
func resourceRemotefileDelete(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
client, err := meta.(*apiClient).getRemoteClient(d)
if err != nil {
return diag.Errorf(err.Error())
return diag.Errorf("error while opening remote client: %s", err.Error())
}

sudo, ok := d.GetOk("conn.0.sudo")
Expand All @@ -186,5 +196,10 @@ func resourceRemotefileDelete(ctx context.Context, d *schema.ResourceData, meta
}
}

err = meta.(*apiClient).closeRemoteClient(d)
if err != nil {
return diag.Errorf("error while closing remote client: %s", err.Error())
}

return diag.Diagnostics{}
}

0 comments on commit 95f6c52

Please sign in to comment.