Skip to content

Commit

Permalink
Updated storage impl
Browse files Browse the repository at this point in the history
  • Loading branch information
niemyjski committed Jan 4, 2024
1 parent ec684b2 commit 82d767f
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions src/Foundatio.Storage.SshNet/Storage/SshNetFileStorage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
using Renci.SshNet.Common;
using Renci.SshNet.Sftp;

namespace Foundatio.Storage;
namespace Foundatio.Storage;

public class SshNetFileStorage : IFileStorage {
private readonly ISftpClient _client;
Expand All @@ -40,10 +40,18 @@ public ISftpClient GetClient() {
return _client;
}

public async Task<Stream> GetFileStreamAsync(string path, CancellationToken cancellationToken = default) {
[Obsolete($"Use {nameof(GetFileStreamAsync)} with {nameof(FileAccess)} instead to define read or write behaviour of stream")]
public Task<Stream> GetFileStreamAsync(string path, CancellationToken cancellationToken = default)
=> GetFileStreamAsync(path, StreamMode.Read, cancellationToken);

public async Task<Stream> GetFileStreamAsync(string path, StreamMode streamMode, CancellationToken cancellationToken = default)

Check failure on line 47 in src/Foundatio.Storage.SshNet/Storage/SshNetFileStorage.cs

View workflow job for this annotation

GitHub Actions / build / build

The type or namespace name 'StreamMode' could not be found (are you missing a using directive or an assembly reference?)

Check failure on line 47 in src/Foundatio.Storage.SshNet/Storage/SshNetFileStorage.cs

View workflow job for this annotation

GitHub Actions / build / build

The type or namespace name 'StreamMode' could not be found (are you missing a using directive or an assembly reference?)
{
if (String.IsNullOrEmpty(path))
throw new ArgumentNullException(nameof(path));

if (streamMode is StreamMode.Write)
throw new NotSupportedException($"Stream mode {streamMode} is not supported.");

EnsureClientConnected();

string normalizedPath = NormalizePath(path);
Expand Down Expand Up @@ -85,7 +93,7 @@ public Task<bool> ExistsAsync(string path) {
throw new ArgumentNullException(nameof(path));

EnsureClientConnected();

string normalizedPath = NormalizePath(path);
_logger.LogTrace("Checking if {Path} exists", normalizedPath);
return Task.FromResult(_client.Exists(normalizedPath));
Expand All @@ -96,10 +104,10 @@ public async Task<bool> SaveFileAsync(string path, Stream stream, CancellationTo
throw new ArgumentNullException(nameof(path));
if (stream == null)
throw new ArgumentNullException(nameof(stream));

string normalizedPath = NormalizePath(path);
_logger.LogTrace("Saving {Path}", normalizedPath);

EnsureClientConnected();

try {
Expand All @@ -108,7 +116,7 @@ public async Task<bool> SaveFileAsync(string path, Stream stream, CancellationTo
} catch (SftpPathNotFoundException ex) {
_logger.LogDebug(ex, "Error saving {Path}: Attempting to create directory", normalizedPath);
CreateDirectory(normalizedPath);

_logger.LogTrace("Saving {Path}", normalizedPath);
await using var sftpFileStream = await _client.OpenAsync(normalizedPath, FileMode.OpenOrCreate, FileAccess.Write, cancellationToken).AnyContext();
await stream.CopyToAsync(sftpFileStream, cancellationToken).AnyContext();
Expand Down Expand Up @@ -139,7 +147,7 @@ public async Task<bool> RenameFileAsync(string path, string newPath, Cancellatio
} catch (SftpPathNotFoundException ex) {
_logger.LogDebug(ex, "Error renaming {Path} to {NewPath}: Attempting to create directory", normalizedPath, normalizedNewPath);
CreateDirectory(normalizedNewPath);

_logger.LogTrace("Renaming {Path} to {NewPath}", normalizedPath, normalizedNewPath);
await _client.RenameFileAsync(normalizedPath, normalizedNewPath, cancellationToken).AnyContext();
} catch (Exception ex) {
Expand All @@ -161,7 +169,7 @@ public async Task<bool> CopyFileAsync(string path, string targetPath, Cancellati
_logger.LogInformation("Copying {Path} to {TargetPath}", normalizedPath, normalizedTargetPath);

try {
await using var stream = await GetFileStreamAsync(normalizedPath, cancellationToken).AnyContext();
await using var stream = await GetFileStreamAsync(normalizedPath, StreamMode.Read, cancellationToken).AnyContext();
if (stream == null)
return false;

Expand Down Expand Up @@ -226,25 +234,25 @@ private void CreateDirectory(string path) {
currentDirectory = String.IsNullOrEmpty(currentDirectory)
? segment
: String.Concat(currentDirectory, "/", segment);
if (_client.Exists(currentDirectory))

if (_client.Exists(currentDirectory))
continue;

_logger.LogInformation("Creating {Directory} directory", directory);
_client.CreateDirectory(currentDirectory);
}
}

private async Task<int> DeleteDirectoryAsync(string path, bool includeSelf, CancellationToken cancellationToken = default) {
int count = 0;

string directory = NormalizePath(path);
_logger.LogInformation("Deleting {Directory} directory", directory);

await foreach (var file in _client.ListDirectoryAsync(directory, cancellationToken).AnyContext()) {
if (file.Name is "." or "..")
continue;

if (file.IsDirectory) {
count += await DeleteDirectoryAsync(file.FullName, true, cancellationToken);
} else {
Expand Down Expand Up @@ -302,12 +310,12 @@ private async Task<List<FileSpec>> GetFileListAsync(string searchPattern = null,

// NOTE: This could be very expensive the larger the directory structure you have as we aren't efficiently doing paging.
int? recordsToReturn = limit.HasValue ? skip.GetValueOrDefault() * limit + limit : null;

_logger.LogTrace(
s => s.Property("SearchPattern", searchPattern).Property("Limit", limit).Property("Skip", skip),
s => s.Property("SearchPattern", searchPattern).Property("Limit", limit).Property("Skip", skip),
"Getting file list recursively matching {Prefix} and {Pattern}...", criteria.Prefix, criteria.Pattern
);

await GetFileListRecursivelyAsync(criteria.Prefix, criteria.Pattern, list, recordsToReturn, cancellationToken).AnyContext();

if (skip.HasValue)
Expand Down Expand Up @@ -415,9 +423,9 @@ private ConnectionInfo CreateConnectionInfo(SshNetFileStorageOptions options) {
}

private void EnsureClientConnected() {
if (_client.IsConnected)
if (_client.IsConnected)
return;

_logger.LogTrace("Connecting to {Host}:{Port}", _client.ConnectionInfo.Host, _client.ConnectionInfo.Port);
_client.Connect();
_logger.LogTrace("Connected to {Host}:{Port} in {WorkingDirectory}", _client.ConnectionInfo.Host, _client.ConnectionInfo.Port, _client.WorkingDirectory);
Expand Down

0 comments on commit 82d767f

Please sign in to comment.