Skip to content

Commit

Permalink
Fix misbehaviour in resources and file permissions; Stabilize tests a…
Browse files Browse the repository at this point in the history
…nd improve coverage
  • Loading branch information
steve-hb committed Feb 21, 2025
1 parent d7d64dc commit 44b704f
Show file tree
Hide file tree
Showing 11 changed files with 200 additions and 71 deletions.
11 changes: 0 additions & 11 deletions internal/provider/data/test/file_data_source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,6 @@ func TestAccFileDataSource(t *testing.T) {

func testAccFileDataSourceConfig(path string) string {
return fmt.Sprintf(`
terraform {
required_providers {
ssh = {
source = "askrella/ssh"
version = "0.1.0"
}
}
}
provider "askrella-ssh" {}
data "ssh_file_info" "test" {
ssh = {
host = "localhost"
Expand Down
38 changes: 22 additions & 16 deletions internal/provider/resource/directory_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func (r *DirectoryResource) Create(ctx context.Context, req resource.CreateReque
}
defer client.Close()

permissions := parsePermissions(plan.Permissions.ValueString())
permissions := ssh.ParsePermissions(plan.Permissions.ValueString())

if exists, _ := client.Exists(ctx, plan.Path.ValueString()); !exists {
err = client.CreateDirectory(ctx, plan.Path.ValueString(), os.FileMode(permissions))
Expand Down Expand Up @@ -334,23 +334,11 @@ func (r *DirectoryResource) Update(ctx context.Context, req resource.UpdateReque
}
defer client.Close()

exists, err := client.Exists(ctx, plan.Path.ValueString())
if err != nil {
resp.Diagnostics.AddError(
"Error determining if directory exists",
fmt.Sprintf("Could determine directory existence: %s", err),
)
return
}
if !exists {
resp.State.RemoveResource(ctx)
return
}

permissions := parsePermissions(plan.Permissions.ValueString())
permissions := ssh.ParsePermissions(plan.Permissions.ValueString())
wantedFileMode := os.FileMode(permissions)

if exists, _ := client.Exists(ctx, plan.Path.ValueString()); !exists {
err = client.CreateDirectory(ctx, plan.Path.ValueString(), os.FileMode(permissions))
err = client.CreateDirectory(ctx, plan.Path.ValueString(), wantedFileMode)
if err != nil {
resp.Diagnostics.AddError(
"Error updating directory",
Expand All @@ -360,6 +348,24 @@ func (r *DirectoryResource) Update(ctx context.Context, req resource.UpdateReque
}
}

fileMode, err := client.GetFileMode(ctx, plan.Path.ValueString())
if err != nil {
resp.Diagnostics.AddError(
"Error retrieving permissions",
fmt.Sprintf("Could not retrieve permissions: %s", err),
)
}
if fileMode != wantedFileMode {
err := client.SetFileMode(ctx, plan.Path.ValueString(), wantedFileMode)
if err != nil {
resp.Diagnostics.AddError(
"Error updating permissions",
fmt.Sprintf("Could not set permissions: %s", err),
)
return
}
}

// Set ownership if specified
if !plan.Owner.IsNull() || !plan.Group.IsNull() {
err = client.SetFileOwnership(ctx, plan.Path.ValueString(), &ssh.FileOwnership{
Expand Down
53 changes: 39 additions & 14 deletions internal/provider/resource/file_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,40 @@ func (r *FileResource) Create(ctx context.Context, req resource.CreateRequest, r
return
}
if exists {
resp.State.RemoveResource(ctx)
return
content, err := client.ReadFile(ctx, plan.Path.ValueString())
if err != nil {
resp.Diagnostics.AddError(
"Error checking file content",
fmt.Sprintf("Could not read file content: %s", err),
)
return
}

// When content does not match the desired state, delete the file and pretend it doesn't exist (anymore)
if content != plan.Content.ValueString() {
err := client.DeleteFile(ctx, plan.Path.ValueString())
if err != nil {
resp.Diagnostics.AddError(
"Error recreating file",
fmt.Sprintf("Could delete file after content mismatch: %s", err),
)
return
}
exists = false
}
}

permissions := parsePermissions(plan.Permissions.ValueString())
permissions := ssh.ParsePermissions(plan.Permissions.ValueString())

err = client.CreateFile(ctx, plan.Path.ValueString(), plan.Content.ValueString(), os.FileMode(permissions))
if err != nil {
resp.Diagnostics.AddError(
"Error creating file",
fmt.Sprintf("Could not create file: %s", err),
)
return
if !exists {
err = client.CreateFile(ctx, plan.Path.ValueString(), plan.Content.ValueString(), os.FileMode(permissions))
if err != nil {
resp.Diagnostics.AddError(
"Error creating file",
fmt.Sprintf("Could not create file: %s", err),
)
return
}
}

// Set ownership if specified
Expand Down Expand Up @@ -368,12 +389,16 @@ func (r *FileResource) Update(ctx context.Context, req resource.UpdateRequest, r
)
return
}
if !exists {
resp.State.RemoveResource(ctx)
return
if exists {
if err := client.DeleteFile(ctx, plan.Path.ValueString()); err != nil {
resp.Diagnostics.AddError(
"Error updating file",
fmt.Sprintf("Could not recreate file: %s", err),
)
}
}

permissions := parsePermissions(plan.Permissions.ValueString())
permissions := ssh.ParsePermissions(plan.Permissions.ValueString())

err = client.CreateFile(ctx, plan.Path.ValueString(), plan.Content.ValueString(), os.FileMode(permissions))
if err != nil {
Expand Down
18 changes: 5 additions & 13 deletions internal/provider/resource/test/directory_resource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,6 @@ import (
func TestAccDirectoryResource(t *testing.T) {
t.Parallel()

// Setup SSH client for verification
sshConfig := ssh.SSHConfig{
Host: "localhost",
Port: 2222,
Username: "testuser",
Password: "testpass",
}

client, err := ssh.NewSSHClient(context.Background(), sshConfig)
require.NoError(t, err)
defer client.Close()
Expand Down Expand Up @@ -62,7 +54,7 @@ func TestAccDirectoryResource(t *testing.T) {
return fmt.Errorf("failed to get directory permissions: %v", err)
}
if mode != os.FileMode(0755) {
return fmt.Errorf("unexpected permissions for creation: got %o, want 0755", mode)
return fmt.Errorf("unexpected permissions for creation: got octal %o, want 0755", mode)
}

// Verify ownership
Expand All @@ -83,10 +75,10 @@ func TestAccDirectoryResource(t *testing.T) {
},
// Update testing
{
Config: testAccDirectoryResourceConfig(dirName, "0775", "testuser", "testuser"),
Config: testAccDirectoryResourceConfig(dirName, "0600", "testuser", "testuser"),
Check: resource.ComposeAggregateTestCheckFunc(
resource.TestCheckResourceAttr("ssh_directory.test", "path", testDirPath),
resource.TestCheckResourceAttr("ssh_directory.test", "permissions", "0775"),
resource.TestCheckResourceAttr("ssh_directory.test", "permissions", "0600"),
resource.TestCheckResourceAttr("ssh_directory.test", "owner", "testuser"),
resource.TestCheckResourceAttr("ssh_directory.test", "group", "testuser"),
resource.TestCheckResourceAttr("ssh_directory.test", "ssh.host", "localhost"),
Expand All @@ -107,8 +99,8 @@ func TestAccDirectoryResource(t *testing.T) {
if err != nil {
return fmt.Errorf("failed to get directory permissions: %v", err)
}
if mode != os.FileMode(0775) {
return fmt.Errorf("unexpected permissions for updating: got %o, want 0775", mode)
if mode != os.FileMode(0600) {
return fmt.Errorf("unexpected permissions for updating: got %o, want 0600", mode)
}

// Verify ownership
Expand Down
8 changes: 0 additions & 8 deletions internal/provider/resource/test/file_resource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,6 @@ import (
func TestAccFileResource(t *testing.T) {
t.Parallel()

// Setup SSH client for verification
sshConfig := ssh.SSHConfig{
Host: "localhost",
Port: 2222,
Username: "testuser",
Password: "testpass",
}

client, err := ssh.NewSSHClient(context.Background(), sshConfig)
require.NoError(t, err)
defer client.Close()
Expand Down
8 changes: 8 additions & 0 deletions internal/provider/resource/test/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package test

import (
"github.com/askrella/askrella-ssh-provider/internal/provider"
"github.com/askrella/askrella-ssh-provider/internal/provider/ssh"
"github.com/hashicorp/terraform-plugin-framework/providerserver"
"github.com/hashicorp/terraform-plugin-go/tfprotov6"
)
Expand All @@ -10,4 +11,11 @@ var (
testAccProtoV6ProviderFactories = map[string]func() (tfprotov6.ProviderServer, error){
"ssh": providerserver.NewProtocol6WithError(provider.New("test")()),
}

sshConfig = ssh.SSHConfig{
Host: "::1",
Port: 2222,
Username: "testuser",
Password: "testpass",
}
)
29 changes: 29 additions & 0 deletions internal/provider/ssh/host_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package ssh

import (
. "github.com/onsi/gomega"
"net"
"net/url"
"testing"
)

func TestHostParsing(t *testing.T) {
RegisterTestingT(t)
addresses := []string{
"127.0.0.1", "localhost", "2a02:4f8:d014:b2f2::1",
}

for _, addr := range addresses {
t.Run(addr, func(t *testing.T) {
RegisterTestingT(t)

ip := net.ParseIP(addr)
if ip != nil {
return
}

_, err := url.Parse(addr)
Expect(err).ToNot(HaveOccurred())
})
}
}
14 changes: 14 additions & 0 deletions internal/provider/ssh/ssh_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,20 @@ func (c *SSHClient) GetFileMode(ctx context.Context, path string) (os.FileMode,
return info.Mode().Perm(), nil
}

// GetFileMode gets the permissions of a file or directory
func (c *SSHClient) SetFileMode(ctx context.Context, path string, mode os.FileMode) error {
ctx, span := otel.Tracer("ssh-provider").Start(ctx, "SetFileMode")
defer span.End()

err := c.SftpClient.Chmod(path, mode)
if err != nil {
c.logger.WithContext(ctx).WithError(err).Error("Failed to set file mode")
return fmt.Errorf("failed to set file mode: %w", err)
}

return nil
}

// GetFileOwnership gets the user and group ownership of a file or directory
func (c *SSHClient) GetFileOwnership(ctx context.Context, path string) (*FileOwnership, error) {
ctx, span := otel.Tracer("ssh-provider").Start(ctx, "GetFileOwnership")
Expand Down
61 changes: 54 additions & 7 deletions internal/provider/ssh/ssh_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,68 @@ package ssh
import (
"context"
"crypto/rand"
. "github.com/onsi/gomega"
"os"
"path"
"testing"

. "github.com/onsi/gomega"
)

var sshConfig = SSHConfig{
Host: "localhost",
Port: 2222,
Username: "testuser",
Password: "testpass",
}

func TestFilePermissions(t *testing.T) {
RegisterTestingT(t)

client, err := NewSSHClient(context.Background(), sshConfig)
Expect(err).ToNot(HaveOccurred())
ctx := context.Background()
basePath := "/home/testuser/ssh_test_" + rand.Text()

testCases := []struct {
name string
filePath string
content string
permissions os.FileMode
}{
{
name: "Test File Permissions 0777",
filePath: basePath + "_1",
content: "Hello World",
permissions: 0777,
},
{
name: "Test File Permissions 0644",
filePath: basePath + "_2",
content: "Hello World",
permissions: 0644,
},
{
name: "Test File Permissions 0600",
filePath: basePath + "_3",
content: "Hello World",
permissions: 0600,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
RegisterTestingT(t)

Expect(client.CreateFile(ctx, tc.filePath, tc.content, tc.permissions)).Should(Succeed())
Expect(client.GetFileMode(ctx, tc.filePath)).To(BeEquivalentTo(tc.permissions))
})
}
}

func TestDirectoryOperations(t *testing.T) {
RegisterTestingT(t)

client, err := NewSSHClient(context.Background(), SSHConfig{
Host: "localhost",
Port: 2222,
Username: "testuser",
Password: "testpass",
})
client, err := NewSSHClient(context.Background(), sshConfig)
Expect(err).ToNot(HaveOccurred())

basePath := "/home/testuser/ssh_test_" + rand.Text()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package resource
package ssh

import "strconv"

func parsePermissions(perms string) uint32 {
func ParsePermissions(perms string) uint32 {
if perms == "" {
return 0644
}
Expand Down
27 changes: 27 additions & 0 deletions internal/provider/ssh/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package ssh

import (
. "github.com/onsi/gomega"
"testing"
)

func TestParsePermissions(t *testing.T) {
RegisterTestingT(t)

tests := []struct {
str string
expected uint32
}{
{"755", 0755},
{"0755", 0755},
{"777", 0777},
{"0777", 0777},
{"0600", 0600},
{"600", 0600},
}
for _, test := range tests {
t.Run(test.str, func(t *testing.T) {
Expect(ParsePermissions(test.str)).To(Equal(test.expected))
})
}
}

0 comments on commit 44b704f

Please sign in to comment.