diff --git a/internal/pkg/agent/cmd/container_init_linux.go b/internal/pkg/agent/cmd/container_init_linux.go index f723a7af7c6..1cfeda095e4 100644 --- a/internal/pkg/agent/cmd/container_init_linux.go +++ b/internal/pkg/agent/cmd/container_init_linux.go @@ -162,12 +162,8 @@ func updateFileCapsFromBoundingSet(executablePath string) (updated bool, err err return false, fmt.Errorf("failed to chown %s: %w", executablePath, err) } - var fileSet interface { - SetFile(pathString string) error - } - // create a new set based on the capabilities of Bounding set - fileSet, err = cap.FromText(capsText) + fileSet, err := cap.FromText(capsText) if err != nil { return false, fmt.Errorf("failed to parse caps text: %w", err) } diff --git a/internal/pkg/agent/cmd/container_init_test.go b/internal/pkg/agent/cmd/container_init_test.go index dcb1adcebd8..57c2b1b740a 100644 --- a/internal/pkg/agent/cmd/container_init_test.go +++ b/internal/pkg/agent/cmd/container_init_test.go @@ -9,6 +9,7 @@ package cmd import ( "os" "path/filepath" + "syscall" "testing" "github.com/stretchr/testify/require" @@ -26,12 +27,8 @@ func Test_chownPaths(t *testing.T) { defer os.RemoveAll(secondParentDir) childDir := filepath.Join(secondParentDir, "child") - err = os.MkdirAll(childDir, 0o777) - require.NoError(t, err) childChildDir := filepath.Join(childDir, "child") - err = os.MkdirAll(childDir, 0o777) - require.NoError(t, err) pathsToChown := distinctPaths{} pathsToChown.addPath(childDir) @@ -45,24 +42,80 @@ func Test_chownPaths(t *testing.T) { require.NoError(t, err) } +func Test_updateFileCapsFromBoundingSet(t *testing.T) { + if os.Geteuid() == 0 { + t.Skip("this test requires non-root user") + return + } + + tmpDir, err := os.MkdirTemp("", "test_chown") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + executable := filepath.Join(tmpDir, "test_exec") + + err = os.WriteFile(executable, []byte{}, 0o7777) + require.NoError(t, err) + + updated, err := updateFileCapsFromBoundingSet(executable) + require.ErrorIs(t, err, syscall.EPERM) + require.False(t, updated) +} + func Test_getMissingBoundingCapsText(t *testing.T) { tc := []struct { - name string - fileCaps []cap.Value - boundCaps []cap.Value - capsText string + name string + fileCaps []cap.Value + fileCapsErr error + boundCaps []cap.Value + boundCapsErr error + capsText string + expectedErr error }{ { - name: "no missing caps", - fileCaps: []cap.Value{cap.CHOWN, cap.SETPCAP}, - boundCaps: []cap.Value{cap.CHOWN, cap.SETPCAP}, - capsText: "", + name: "no missing caps", + fileCaps: []cap.Value{cap.CHOWN, cap.SETPCAP}, + fileCapsErr: nil, + boundCaps: []cap.Value{cap.CHOWN, cap.SETPCAP}, + boundCapsErr: nil, + capsText: "", + expectedErr: nil, + }, + { + name: "missing caps", + fileCaps: []cap.Value{cap.CHOWN, cap.SETPCAP}, + fileCapsErr: nil, + boundCaps: []cap.Value{cap.CHOWN, cap.SETPCAP, cap.DAC_OVERRIDE}, + boundCapsErr: nil, + capsText: "cap_chown,cap_dac_override,cap_setpcap=eip", + expectedErr: nil, + }, + { + name: "no data err", + fileCaps: nil, + fileCapsErr: syscall.ENODATA, + boundCaps: []cap.Value{cap.CHOWN, cap.SETPCAP, cap.DAC_OVERRIDE}, + boundCapsErr: nil, + capsText: "cap_chown,cap_dac_override,cap_setpcap=eip", + expectedErr: nil, }, { - name: "missing caps", - fileCaps: []cap.Value{cap.CHOWN, cap.SETPCAP}, - boundCaps: []cap.Value{cap.CHOWN, cap.SETPCAP, cap.DAC_OVERRIDE}, - capsText: "cap_chown,cap_dac_override,cap_setpcap=eip", + name: "file caps permission err", + fileCaps: nil, + fileCapsErr: syscall.EPERM, + boundCaps: []cap.Value{cap.CHOWN, cap.SETPCAP, cap.DAC_OVERRIDE}, + boundCapsErr: nil, + capsText: "", + expectedErr: syscall.EPERM, + }, + { + name: "bound caps permission err", + fileCaps: nil, + fileCapsErr: nil, + boundCaps: []cap.Value{cap.CHOWN, cap.SETPCAP, cap.DAC_OVERRIDE}, + boundCapsErr: syscall.EPERM, + capsText: "", + expectedErr: syscall.EPERM, }, } @@ -74,6 +127,10 @@ func Test_getMissingBoundingCapsText(t *testing.T) { for _, tt := range tc { t.Run(tt.name, func(t *testing.T) { capBound = func(val cap.Value) (bool, error) { + if tt.boundCapsErr != nil { + return false, tt.boundCapsErr + } + for _, boundCap := range tt.boundCaps { if boundCap == val { return true, nil @@ -82,17 +139,23 @@ func Test_getMissingBoundingCapsText(t *testing.T) { return false, nil } capGetFile = func(path string) (*cap.Set, error) { - set := cap.NewSet() + if tt.fileCapsErr != nil { + return nil, tt.fileCapsErr + } + set := cap.NewSet() if err := set.SetFlag(cap.Effective, true, tt.fileCaps...); err != nil { return nil, err } - return set, nil } capsText, err := getMissingBoundingCapsText("non_existent") - require.NoError(t, err) + if tt.expectedErr != nil { + require.ErrorIs(t, err, tt.expectedErr) + } else { + require.NoError(t, err) + } require.Equal(t, tt.capsText, capsText) }) }