Skip to content

Commit

Permalink
sftpfs.fileSystem.Close closes connection when refcound gets 0
Browse files Browse the repository at this point in the history
  • Loading branch information
ungerik committed Dec 21, 2023
1 parent 3332eb9 commit afb9f5e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
5 changes: 4 additions & 1 deletion ftpfs/ftpfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,10 @@ func (f *fileSystem) Close() error {
if f.conn == nil {
return nil // already closed
}
fs.Unregister(f)
count := fs.Unregister(f)
if count > 1 {
return nil // still referenced
}
err := f.conn.Quit()
f.conn = nil
return err
Expand Down
14 changes: 9 additions & 5 deletions sftpfs/sftpfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,15 @@ func DialAndRegister(ctx context.Context, address string, loginCallback LoginCal
// is already registered. If not, then a new connection is dialed and registered.
// The returned free function has to be called to decrease the file system's
// reference count and close it when the reference count reaches 0.
func EnsureRegistered(ctx context.Context, address string, loginCallback LoginCallback, hostKeyCallback ssh.HostKeyCallback) (free func(), err error) {
func EnsureRegistered(ctx context.Context, address string, loginCallback LoginCallback, hostKeyCallback ssh.HostKeyCallback) (free func() error, err error) {
u, username, password, prefix, err := prepareDial(address, loginCallback, hostKeyCallback)
if err != nil {
return nil, err
}
f := fs.GetFileSystemByPrefixOrNil(prefix)
if f != nil {
fs.Register(f) // Increase ref count
return func() { fs.Unregister(f) }, nil
return func() error { fs.Unregister(f); return nil }, nil
}

client, err := dial(ctx, u.Host, username, password, hostKeyCallback)
Expand All @@ -167,7 +167,7 @@ func EnsureRegistered(ctx context.Context, address string, loginCallback LoginCa
prefix: prefix,
}
fs.Register(f)
return func() { fs.Unregister(f) }, nil
return func() error { return f.Close() }, nil
}

func dial(ctx context.Context, host, user, password string, hostKeyCallback ssh.HostKeyCallback) (*sftp.Client, error) {
Expand Down Expand Up @@ -436,7 +436,11 @@ func (f *fileSystem) Close() error {
if f.client == nil {
return nil // already closed
}
fs.Unregister(f)
count := fs.Unregister(f)
if count > 1 {
return nil // still referenced
}
err := f.client.Close()
f.client = nil
return nil
return err
}

0 comments on commit afb9f5e

Please sign in to comment.